CV Week: Итоговое задание¶
На лекции и семинаре мы разбирали как дистиллировать многошаговую диффузионную модель в малошагового студента, и тем самым будет работать на порядок быстрее учителя.
Один из подходов, который мы разбирали Consistency Distillation. В этом задании, мы закрепим материал, который был на лекции и семинаре и реализуем этот фреймворк, затрагивая различные нюансы.
В этом задании мы будем дистиллировать модель Stable Diffusion 1.5 (SD1.5) для генерации картинок по текстовому описанию.
Вам предстоит выполнить 8 небольших заданий, которые приведут нас к неплохой модели для генерации картинок за 4 шага, работая в органиченных условиях колаба.
# torch 2.4.1+cu124
!pip install diffusers==0.30.3 peft==0.8.2 huggingface_hub==0.23.4
Теормин¶
Диффузионные модели¶
Задан прямой диффузионный процесс, который переводит чистые картинки в шум с помощью распределения $q(\mathbf{x}_t | \mathbf{x}_0)= {N}(\mathbf{x}_t | \alpha_t \mathbf{x}_0, \sigma^2_t I)$
Таким образом, мы можем получаться зашумленные картинки по следующей формуле: $\mathbf{x}_t = \alpha_t \mathbf{x}_0 + \sigma_t \epsilon$, где $\epsilon{\sim} {N}(0, I)$ (1)
$\alpha_t, \sigma_t$ задают процесс зашумления. Здесь мы будем иметь дело с variance preserving (VP) процессом, т. е., $\alpha^2_t = 1 - \sigma^2_t$.
Диффузионная модель (ДМ) пытается решить обратную задачу: из шума порождать новые картинки. Важно, что диффузионный процесс можно описать следующим обыкновенным дифференциальным уравнением (ОДУ):
$dx = \left[ f(\mathbf{x}, t) - \frac{1}{2} \nabla_{\mathbf{x}_t} \log p_t(\mathbf{x}) \right] dt$, (2)
где $f(\mathbf{x}, t)$ известен из заданного процесса зашумления, а $\nabla_{\mathbf{x}_t} \log p_t(\mathbf{x}_t)$ (скор функцию) оцениваем с помощью нейросети: $s_\theta(\mathbf{x}_t, t) \approx \nabla_{\mathbf{x}_t} \log p_t(\mathbf{x}_t)$. Таким образом, имея оценку на $\nabla_{\mathbf{x}_t} \log p_t(\mathbf{x})$, мы можем решить это ОДУ, стартуя со случайного шума, и получить картинку.
SD1.5 использует $\epsilon$-параметризацию, т.е., UNet пытается предсказать шум, который мы добавили на картинку по формуле (1). Оценку скор функции можно получить, пользуясь результатом, вытекающим из формулы Твидди: $s_\theta(\mathbf{x}_t, t) = - \frac{\epsilon_\theta(\mathbf{x}_t, t)} { \sigma_t}$
Чтобы решить ОДУ (2), нам нужно воспользоваться каким-то численным методом (солвером). В этом задании мы будем работать с не самым эффектным, но самым популярным солвером: DDIM, который является адаптированным методом Эйлера под диффузионный ОДУ.
Для VP процесса переход с помощью DDIM с шага $t$ на $s$ можно сделать следующим образом:
$ x_s = DDIM(\mathbf{x}_t, t, s) = \alpha_s \cdot \left(\frac{\mathbf{x}_t - \sigma_t \epsilon_\theta}{\alpha_t} \right) + \sigma_s \epsilon_\theta $
Этот переход можно интерпретировать так: получаем оценку на чистую картинку $\mathbf{x}_0$ на шаге $t$, используя $\frac{\mathbf{x}_t - \sigma_t \epsilon_\theta}{\alpha_t}$, а потом снова зашумляем эту оценку на шаг $s$ по формуле (1), но только используем не случайный шум, а шум предсказанный моделью $\epsilon_\theta$.
Используя DDIM для SD1.5, можем получать хорошие картинки за 50 шагов.
SD1.5 - латентная ДМ, т.е. модель работает не в пиксельном пространстве, а в латентном пространстве VAE. Таким образом SD1.5 состоит из следующих компонент:
- VAE - переводит $3{\times}512{\times}512$ картинки в латенты $4{\times}64{\times}64$ и может декодировать их обратно в картинки.
- Текстовый энкодер - извлекает текстовые признаки из промпта. Эти признаки будут подаваться в диффузионную модель, чтобы дать модели информацию, что именно хотим сгенерировать
- Диффузионная модель - UNet, работающий на "латентных картинках" $4{\times}64{\times}64$.
Консистенси модели¶
Общая идея¶
Главная цель дистилляции диффузии - уменьшить количество шагов ДМ, при этом сохранив высокое качество картинок.
Консистенси модели (Consistency Models | CM) - класс моделей, где мы хотим выучить "консистенси функцию" $f_\theta(\mathbf{x}_t)$ - с любой точки $\mathbf{x}_{t}$ траектории диффузионного ОДУ (2) сразу предсказывать $\mathbf{x}_{0}$ (чистые данные) за один шаг. Если мы идеально выучим консистенси функцию, то сможем шагать из чистого шума сразу в картинку, что супер эффективно в отличии от генерации ДМ.
Отметим, что консистенси модель можно учить как независимую генеративную модель, без предобученной ДМ, и в задании 3 вам предстоит подумать, как это можно сделать.

Консистенси дистилляция (Consistency Distillation | CD) - подход, когда для обучения CM, мы используем предобученную ДМ. ДМ нам дает качественную инициализацию модели и уже обученную скор функцию, что сильно упрощает сходимость консистенси моделей.
Обучение CM¶
Главная принцип обучения консистенси моделей заключается в попытке удовлетворить self-consistency св-ву: выход CM на двух соседних точках траектории $\mathbf{x}_{t}$ и $\mathbf{x}_{t-1}$ должен совпадать по какой-то мере близости, например L2 расстояние: $\lVert f_\theta(\mathbf{x}_{t-1}) - f_\theta(\mathbf{x}_{t}) \rVert^2_2$.
Заметим, что self-consistency св-во удовлетворить очень просто без какого-либо обучения, взяв, например $f_\theta(\mathbf{x}_{t}) \equiv 0$.
Поэтому, чтобы избежать вырожденных решений, нам необходимо выставить граничное условие (boundary condition), которое будет требовать, чтобы в самой левой точке траектории около 0, модель предсказывала картинку, которую получает на вход: $f_\theta(\mathbf{x}_{\epsilon}) = \mathbf{x}_{\epsilon}$.
Важное практическое замечание: Для обеих точек траектории мы применяем одну и ту же модель $f_\theta(\cdot)$. Но выход модели на шаге ${t-1}$ является "таргетом" для выхода модели на шаге $t$ и поэтому выполнение модели для шага $t-1$ выполняется в torch.no_grad режиме.
Как получаться две соседние точки на траектории ОДУ?
Берем случайную картинку $\mathbf{x}_0$ из датасета.
Точку $\mathbf{x}_t$ получаем с помощью прямого процесса зашумления: $\mathbf{x}_t = q(\mathbf{x}_t | \mathbf{x}_0)$
Чтобы получить соседнюю точку $\mathbf{x}_{t-1}$, нам нужно сделать шаг по траектории ОДУ, используя, например, DDIM солвер.
В консистенси дистилляции, мы делаем шаг предобученной ДМ: $\mathbf{x}_{t-1} = DDIM(\epsilon_\theta(\mathbf{x}_t, t), \mathbf{x}_t, t, t-1)$
from tqdm.auto import tqdm
import csv
import os
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline, LCMScheduler, UNet2DConditionModel, DDIMScheduler
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
%matplotlib inline
import matplotlib.pyplot as plt
The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.
0it [00:00, ?it/s]
#---------------------
# Visualization utils
#---------------------
def visualize_images(images):
assert len(images) == 4
plt.figure(figsize=(12, 3))
for i, image in enumerate(images):
plt.subplot(1, 4, i+1)
plt.imshow(image)
plt.axis('off')
plt.subplots_adjust(wspace=-0.01, hspace=-0.01)
#--------------
# Tensor utils
#--------------
def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
#---------------
# Dataset utils
#---------------
class COCODataset(torch.utils.data.Dataset):
def __init__(self, root_dir, subset_name="train2014_5k", transform=None, max_cnt=None):
"""
Arguments:
root_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied
on a sample.
"""
self.root_dir = root_dir
self.transform = transform
self.extensions = (
".jpg",
".jpeg",
".png",
".ppm",
".bmp",
".pgm",
".tif",
".tiff",
".webp",
)
sample_dir = os.path.join(root_dir, subset_name)
# Collect sample paths
self.samples = sorted(
[
os.path.join(sample_dir, fname)
for fname in os.listdir(sample_dir)
if fname[-4:] in self.extensions
],
key=lambda x: x.split("/")[-1].split(".")[0],
)
self.samples = (
self.samples if max_cnt is None else self.samples[:max_cnt]
) # restrict num samples
# Collect captions
self.captions = {}
with open(
os.path.join(root_dir, f"{subset_name}.csv"), newline="\n"
) as csvfile:
spamreader = csv.reader(csvfile, delimiter=",")
for i, row in enumerate(spamreader):
if i == 0:
continue
self.captions[row[1]] = row[2]
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
sample_path = self.samples[idx]
sample = Image.open(sample_path).convert("RGB")
if self.transform:
sample = self.transform(sample)
return {
"image": sample,
"text": self.captions[os.path.basename(sample_path)],
"idxs": idx, }
Модель учителя (SD1.5)¶
Задание №1¶
Давайте для начала загрузим модель StableDiffusion 1.5 и сгенерируем ей картинки за 50 шагов.
Важно: для экономии памяти, загружаем все компоненты модели в FP16. Не забываем положить модель на GPU.
pipe = StableDiffusionPipeline.from_pretrained(
'sd-legacy/stable-diffusion-v1-5',
torch_dtype=torch.float16,
safety_checker=None,
).to('cuda')
# Проверяем, что все компоненты модели в FP16 и на cuda
assert pipe.unet.dtype == torch.float16 and pipe.unet.device.type == 'cuda'
assert pipe.vae.dtype == torch.float16 and pipe.vae.device.type == 'cuda'
assert pipe.text_encoder.dtype == torch.float16 and pipe.text_encoder.device.type == 'cuda'
# Заменяем дефолтный сэмплер на DDIM
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
pipe.scheduler.timesteps = pipe.scheduler.timesteps.cuda()
pipe.scheduler.alphas_cumprod = pipe.scheduler.alphas_cumprod.cuda()
# Отдельно извлечем модель учителя, которую потом будем дистиллировать
teacher_unet = pipe.unet
model_index.json: 0%| | 0.00/541 [00:00<?, ?B/s]
Fetching 13 files: 0%| | 0/13 [00:00<?, ?it/s]
model.safetensors: 0%| | 0.00/492M [00:00<?, ?B/s]
text_encoder/config.json: 0%| | 0.00/617 [00:00<?, ?B/s]
scheduler/scheduler_config.json: 0%| | 0.00/308 [00:00<?, ?B/s]
(…)ature_extractor/preprocessor_config.json: 0%| | 0.00/342 [00:00<?, ?B/s]
tokenizer/special_tokens_map.json: 0%| | 0.00/472 [00:00<?, ?B/s]
tokenizer/vocab.json: 0%| | 0.00/1.06M [00:00<?, ?B/s]
tokenizer/tokenizer_config.json: 0%| | 0.00/806 [00:00<?, ?B/s]
tokenizer/merges.txt: 0%| | 0.00/525k [00:00<?, ?B/s]
vae/config.json: 0%| | 0.00/547 [00:00<?, ?B/s]
diffusion_pytorch_model.safetensors: 0%| | 0.00/335M [00:00<?, ?B/s]
Loading pipeline components...: 0%| | 0/6 [00:00<?, ?it/s]
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
Теперь сгенерируем картинки за 50 шагов. Вам нужно написать вызов pipe и передать в него промпт, число шагов генерации, генератор случайных чисел, гайденс скейл и указать, чтобы сгенерировалось 4 картинки на промпт.
prompt = "A sad puppy with large eyes"
guidance_scale = 7.5
generator = torch.Generator('cuda').manual_seed(1)
images = pipe(
prompt=prompt,
num_inference_steps=50,
num_images_per_prompt=4,
generator=generator,
guidance_scale=guidance_scale,
).images
visualize_images(images)
0%| | 0/50 [00:00<?, ?it/s]
Давайте посмотрим, что выдаст модель за 4 шага. Все то же самое, что и выше, просто поменяем число шагов.
generator = torch.Generator('cuda').manual_seed(1)
images = pipe(
prompt=prompt,
num_inference_steps=4,
num_images_per_prompt=4,
generator=generator,
guidance_scale=guidance_scale,
).images
visualize_images(images)
0%| | 0/4 [00:00<?, ?it/s]
На 4 шагах картинки получаются размазанными. Давайте постараемся починить их.
Создаем датасет¶
Чтобы ДЗ было легко выполнимым на colab, мы будем учить консистенси модели на небольшой обучающей выборке из 5000 пар текст-картинка из COCO датасета. Интересное свойство консистенси моделей - они могут сходиться до адекватного качества за несколько сотен шагов. Качество все еще будет не идеальным, но фазовый переход уже должен быть заметен.
Данные можно загрузить с помощью команд в ячейке ниже. В локальной текущей директории ./ должны появиться:
- Папка train2014_5k с 5000 картинками
- Файл train2014_5k.csv с 5000 промптами
Данные парсятся корректным образом в уже реализованном классе COCODataset.
# Колаб
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
# Kaggle
!pip install PyDrive
# Загрузка
!wget https://storage.yandexcloud.net/yandex-research/train2014_5k.tar.gz
!tar -C "data" -xzf train2014_5k.tar.gz
Замечание: для более быстрого дебаггинга можете взять, например, 2500 картинок и прогнать на всей выборке только в самом конце. 2500 картинок должно быть достаточно для понимания корректно ли реализованы функции. Совсем для первичного дебаггинга можно взять еще меньше картинок.
from torchvision import transforms
transform = transforms.Compose(
[
transforms.Resize(512),
transforms.CenterCrop(512),
transforms.ToTensor(),
lambda x: 2 * x - 1,
]
)
colab = True
if colab:
data_path = "./drive/MyDrive/cv_data"
else:
data_path = "data"
dataset = COCODataset(data_path,
subset_name="train2014_5k",
transform=transform,
# max_cnt=2500
)
assert len(dataset) == 5000 # 2500
batch_size = 8 # Рекоммендуемы размер батча на Colab
train_dataloader = torch.utils.data.DataLoader(
dataset=dataset, shuffle=True, batch_size=batch_size, drop_last=True
)
@torch.no_grad()
def prepare_batch(batch, pipe):
"""
Предобработка батча картинок и текстовых промптов.
Маппим картинки в латентное пространство VAE.
Извлекаем эмбеды промптов с помощью текстового энкодера.
Params:
Return:
latents: torch.Tensor([B, 4, 64, 64], dtype=torch.float16)
prompt_embeds: torch.Tensor([B, 77, D], dtype=torch.float16)
"""
# Токенизируем промпты
text_inputs = pipe.tokenizer(
batch['text'],
padding="max_length",
max_length=pipe.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
# Извлекаем эмбеды промптов с помощью текстового энкодера
prompt_embeds = pipe.text_encoder(text_inputs.input_ids.cuda())[0]
# Переводим картинки в латентное пространство VAE
image = batch['image'].to("cuda", dtype=torch.float16)
latents = pipe.vae.encode(image).latent_dist.sample()
latents = latents * pipe.vae.config.scaling_factor
return latents, prompt_embeds
Подготовка моделей и оптимизатора¶
Для начала создаем обучаемую модель: UNet инициализируемый весами SD1.5. Вам нужно воспользоваться классом UNet2DConditionModel и загрузить отдельно только UNet модель из SD1.5.
Отметим, что эта модель у нас будет храниться в полной точности FP32, потому что обучение параметров в FP16 может приводить к нестабильностям и низкому качеству.
unet = UNet2DConditionModel.from_pretrained(
'sd-legacy/stable-diffusion-v1-5',
subfolder='unet',
torch_dtype=torch.float32,
).to('cuda').train()
assert unet.dtype == torch.float32
assert unet.training
/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: The secret `HF_TOKEN` does not exist in your Colab secrets. To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session. You will be able to reuse this secret in all of your notebooks. Please note that authentication is recommended but still optional to access public models or datasets. warnings.warn(
unet/config.json: 0%| | 0.00/743 [00:00<?, ?B/s]
diffusion_pytorch_model.safetensors: 0%| | 0.00/3.44G [00:00<?, ?B/s]
Для экономии памяти во время обучения будем учить не параметры самой модели, а добавим в нее обучаемые LoRA адаптеры с малым числом параметров.
LoRA представляет собой маленькую добавку к весам модели, где на одну матрицу весов $W \in \mathbb{R}^{m{\times}n} $ обучаются две низкоранговые матрицы $W_A \in \mathbb{R}^{k{\times}n}$ и $W_B \in \mathbb{R}^{k{\times}m}$, где $k$ - ранг матрицы сильно меньше $m$ и $n$.
Тем самым, новая обученная матрица весов может быть представлена как $\hat{W} = W + \Delta W = W + W^T_B W_A$.
Во время инференса $\Delta W$ можно вмержить в $W$ и получить итоговую модель.
Также частая практика оставлять адаптеры как есть, чтобы была возможность для одной базовой модели учить несколько адаптеров под разные задачи и переключаться между ними по необходимости.
Если не мержить адаптеры, то вычисления для линейного слоя происходят как на картинке ниже.
# Указываем к каким слоям модели мы будет добавлять адаптеры.
lora_modules = [
"to_q", "to_k", "to_v", "to_out.0", "proj_in", "proj_out",
"ff.net.0.proj", "ff.net.2", "conv1", "conv2", "conv_shortcut",
"downsamplers.0.conv", "upsamplers.0.conv", "time_emb_proj"
]
lora_config = LoraConfig(
r=64, # задает ранг у матриц A и B в LoRA.
target_modules=lora_modules
)
# Создаем обертку исходной UNet модели с LoRA адаптерами, используя библиотеку PEFT
cm_unet = get_peft_model(unet, lora_config, adapter_name="ct")
# Включаем gradient checkpointing - важная техника для экономии памяти во время обучения
cm_unet.enable_gradient_checkpointing()
# Создаем оптимизатор
optimizer = torch.optim.AdamW(cm_unet.parameters(), lr=1e-4)
# Задаем лосс функцию для CM обжектива. В базовом варианте разумно взять L2
# По умолчанию, она уже выдает усредненное значение по всем размерностям
mse_loss = torch.nn.functional.mse_loss
Задание №2 (0.5 балла, сдается в контесте)¶
Реализация шага DDIM¶
Шаг с помощью DDIM с $\mathbf{x}_t$ на $\mathbf{x}_s$ можно сделать следующим образом:
$ \mathbf{x}_s = DDIM(\epsilon_\theta, \mathbf{x}_t, t, s) = \alpha_s \cdot \left(\frac{\mathbf{x}_t - \sigma_t \epsilon_\theta}{\alpha_t} \right) + \sigma_s \epsilon_\theta $
Вам нужно реализовать эту формулу в уже готовом шаблоне ниже. Чтобы корректно выполнить задание, вам нужно задать $\alpha_t$ и $\sigma_t$ имея DDIMScheduler. **Обратите внимание на аттрибут *scheduler.alphas_cumprod***, который задает $\bar\alpha_{t} = \prod^t_{i=1} (1-\beta_i)$ в классической DDPM формулировке: Denoising Diffusion Probabilistic Models.
def ddim_solver_step(model_output, x_t, t, s, scheduler):
"""
Шаг DDIM солвера для VP процесса зашумления и eps-prediction модели
params:
model_output: torch.Tensor[B, 4, 64, 64] - предсказание модели - шум eps
x_t: torch.Tensor[B, 4, 64, 64] - сэмплы на шаге t
t: torch.Tensor[B] - номер текущего шага
s: torch.Tensor[B] - номер следующего шага
scheduler: DDIMScheduler - расписание диффузионного процесса, чтобы получить alpha и sigma
"""
alphas = torch.sqrt(scheduler.alphas_cumprod)
sigmas = torch.sqrt(1 - scheduler.alphas_cumprod)
sigmas_s = extract_into_tensor(sigmas, s, x_t.shape)
alphas_s = extract_into_tensor(alphas, s, x_t.shape)
sigmas_t = extract_into_tensor(sigmas, t, x_t.shape)
alphas_t = extract_into_tensor(alphas, t, x_t.shape)
# Выставляем крайние значения alpha и sigma, чтобы выполнялись граничные условия
alphas_s[s == 0] = 1.0
sigmas_s[s == 0] = 0.0
alphas_t[t == 0] = 1.0
sigmas_t[t == 0] = 0.0
x_0 = (x_t - sigmas_t * model_output) / alphas_t # x0 оценка на шаге t
x_s = alphas_s * x_0 + sigmas_s * model_output # Переход на шаг s
return x_s
Реализация процесса зашумления (q sample)¶
Аналогично, нам нужен процесс зашумления $q(\mathbf{x}_t | \mathbf{x}_0)= {N}(\mathbf{x}_t | \alpha_t \mathbf{x}_0, \sigma^2_t I)$
$\mathbf{x}_t = \alpha_t \mathbf{x}_0 + \sigma_t \epsilon$, где $\epsilon{\sim} {N}(0, I)$
def q_sample(x, t, scheduler, noise=None):
alphas = torch.sqrt(scheduler.alphas_cumprod)
sigmas = torch.sqrt(1 - scheduler.alphas_cumprod)
if noise is None:
noise = torch.randn_like(x)
sigmas_t = extract_into_tensor(sigmas, t, x.shape)
alphas_t = extract_into_tensor(alphas, t, x.shape)
alphas_t[t == 0] = 1.0
sigmas_t[t == 0] = 0.0
x_t = alphas_t * x + sigmas_t * noise
return x_t
Consistency Training¶
Обучение консистенси моделей без учителя называется Consistency Training (CT). В таком случае CM можно рассматривать как отдельный вид генеративных моделей. Давайте начнем именно с этого подхода и обучим нашу первую консистенси модель на базе SD1.5.
Задание №3¶
Задание №3.1 (0.5 балла, сдается в контесте)¶
В консиcтенси дистилляции модель учителя используется для получения второй точки на траектории ODE. Можем ли мы попробовать оценить соседнюю точку аналитически?
Вам предлагается вывести это самим, используя формулу DDIM шага выше и вспомнив, как мы оцениваем скор функции в denoising score matching-e:
$\epsilon_\theta(x_t, t) = - \sigma_t s_\theta(x_t, t)$
$s_\theta(x_t, t) \approx \nabla_{x_t} \log q(x_t) = \mathop{\mathbb{E}}_{\mathbf{x}\sim p_{data}}\left [ \nabla_{\mathbf{x}_t} \log q(\mathbf{x}_t | \mathbf{x}) \vert \mathbf{x}_t \right ] \approx \nabla_{\mathbf{x}_t} \log q(\mathbf{x}_t \vert \mathbf{x})$
< YOUR DERIVATION HERE >
Я много что попробовал сдать в контесте, но зашла только эта формула.
Имеем процесс зашумления для $x_t$:
$x_t = \alpha_t x_0 + \sigma_t \epsilon$
Выражаем отсюда $\epsilon$:
$\epsilon = \frac{x_t - \alpha_t x_0}{\sigma_t}$
Подставляем в процесс зашумления для $x_s$ и упрощаем:
$x_s = \alpha_s x_0 + \sigma_s \epsilon$
$x_s = \alpha_s x_0 + \sigma_s \frac{x_t - \alpha_t x_0}{\sigma_t}$
$x_s = \frac{\sigma_s}{\sigma_t} x_t + (\alpha_s - \frac{\sigma_s}{\sigma_t} \alpha_t) x_0$
Если возникнут трудность, можно обратиться к оригинальной статье.
Теперь реализуем то, что у вас получилось в функции ниже.
def get_xs_from_xt_naive(
x_0, x_t, t, s, # Не все эти аргументы могут быть вам нужны
scheduler,
noise=None,
**kwargs
):
"""
Получение точки x_s в CT режиме, т.е., аналитически.
"""
alphas = torch.sqrt(scheduler.alphas_cumprod)
sigmas = torch.sqrt(1 - scheduler.alphas_cumprod)
sigmas_t = extract_into_tensor(sigmas, t, x_0.shape)
alphas_t = extract_into_tensor(alphas, t, x_0.shape)
sigmas_s = extract_into_tensor(sigmas, s, x_0.shape)
alphas_s = extract_into_tensor(alphas, s, x_0.shape)
alphas_t[t == 0] = 1.0
sigmas_t[t == 0] = 0.0
alphas_s[s == 0] = 1.0
sigmas_s[s == 0] = 0.0
if x_t is None:
x_s = q_sample(x_0, t, scheduler, noise)
else:
x_s = (sigmas_s / sigmas_t) * x_t + (alphas_s - (sigmas_s / sigmas_t) * alphas_t) * x_0
return x_s
Задание №3.2¶
Ниже предстален шаблон функции, которая считает лосс для консистенси моделей. Вам нужно правильно заполнить пропуски, чтобы получилась корректная функция.
def cm_loss_template(
latents, prompt_embeds, # батч латентов и текстовых эмбедов
unet, scheduler,
# Функции, которые будем постепенно менять из задания к заданию
loss_fn: callable,
get_boundary_timesteps: callable,
get_xs_from_xt: callable,
num_timesteps=1000,
step_size=20, # Указываем с каким интервалом берем шаги s и t.
):
# Сэмплируем случайные шаги t для каждого элемента батча t ~ U[step_size-1, 999]
assert num_timesteps == 1000
num_intervals = num_timesteps // step_size
index = torch.randint(1, num_intervals, (len(latents),), device=latents.device).long() # [1, num_intervals]
t = step_size * index - 1
s = torch.clamp(t - step_size, min=0)
boundary_timesteps = get_boundary_timesteps(
s, num_timesteps=num_timesteps
)
# Сэмплируем x_t
noise = torch.randn_like(latents)
x_t = q_sample(latents, t, scheduler, noise)
# with <YOUR CODE HERE>: # для реализации mixed-precision обучения в задании №4
with torch.amp.autocast("cuda", torch.float16): # Mixed precision
noise_pred = unet(x_t.float(), t,
encoder_hidden_states=prompt_embeds.float(),
).sample
# Получаем оценку в граничной точке для x_t
boundary_pred = ddim_solver_step(
model_output=noise_pred, x_t=x_t, t=t, s=boundary_timesteps, scheduler=scheduler
)
# Получаем сэмпл x_s из x_t
x_s = get_xs_from_xt(
latents, x_t, t, s,
scheduler,
prompt_embeds=prompt_embeds,
noise=noise,
)
# Предсказание "таргет моделью"
with torch.no_grad(), torch.amp.autocast("cuda", torch.float16):
target_noise_pred = unet(x_s, s, encoder_hidden_states=prompt_embeds).sample
# Получаем оценку в граничной точке для x_s
boundary_target = ddim_solver_step(
model_output=target_noise_pred, x_t=x_s, t=s, s=boundary_timesteps, scheduler=scheduler
)
loss = loss_fn(boundary_pred, boundary_target)
return loss
import functools
def get_zero_boundary_timesteps(t, **kwargs):
"""
Определяем шаги где будут срабатывать граничные условия.
Для классических СM это t=0.
"""
return torch.zeros_like(t)
ct_loss = functools.partial(
cm_loss_template,
loss_fn=mse_loss,
get_boundary_timesteps=get_zero_boundary_timesteps,
get_xs_from_xt=get_xs_from_xt_naive
)
assert cm_unet.active_adapter == 'ct'
Задание №4¶
Эффективное обучение¶
Данное задание рассчитано на успешное выполнение на colab с бесплатной Tesla T4 c 15GB VRAM. Однако учить даже относительно небольшие T2I модели масштаба SD1.5 уже на коллабе в лоб проблематично.
Для этого нам нужно применить ряд инженерных техник, чтобы уместиться в данный бюджет и учиться за разумное время.
Список техник
- Включить gradient checkpointing для обучемой модели
- Добавить LoRA (Low Rank Adapters) адаптеры, чтобы учить не все веса, а только 10% добавочных весов
- Использовать gradient accumulation, чтобы делать итерацию обучения по бОльшему батчу, чем влезает по памяти
- Добавить mixed precision FP16/FP32 обучение модели для скорости. Обычно еще и память экономится, но в случае LoRA обучения + gradient checkpointing на память сильно влиять не должно, но зато станет быстрее.
- Мульти-GPU обучение - распределение вычислений по нескольким GPU.
1-2) Мы уже применили за вас выше
3-4) Предстоит реализовать вам самим в соотвествующей секции ниже
5 ) Недоступно, так как работаем на одной карточке
Обучающий цикл¶
Вам дан код обучения модель в полной точности (FP32) c батчом 8. К сожалению, на Tesla T4 мы не влезем по памяти. Поэтому в ячейке ниже вам нужно модифицировать цикл, чтобы он работал в mixed precision FP16 и добавить gradient accumulation.
Про реализацию mixed-precision в pytorch можно перейти по ссылке: Mixed-precision обучение
Обратите внимание: вам еще нужно добавить одну строчку кода в cm_loss_template в соответствующем плейсхолдере.
Замечание: В начале обучения значения лосса должны быть в окрестности 0.0007-0.001. Ничего страшного, что лосс не падает, для CM это нормально. В конце обучения лосс может доходить до 0.005-0.01
from torch.cuda.amp import autocast, GradScaler
def train_loop(model, pipe, train_dataloader, optimizer, loss_fn, num_grad_accum=1):
torch.cuda.empty_cache()
# Создаем скейлер для mixed precision
scaler = GradScaler()
# Итерация по батчам
for i, batch in enumerate(tqdm(train_dataloader)):
latents, prompt_embeds = prepare_batch(batch, pipe)
# Forward + backward с учетом gradient accumulation
with autocast(dtype=torch.float16): # Включаем mixed precision
loss = loss_fn(latents, prompt_embeds, model, pipe.scheduler) / num_grad_accum
# Backward с GradScaler
scaler.scale(loss).backward()
# Обновляем параметры каждые num_grad_accum шагов
if (i + 1) % num_grad_accum == 0:
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True) # Сбрасываем градиенты
# Логирование
if i % num_grad_accum == 0 or (i + 1) % num_grad_accum == 0:
print(f"Step {i + 1}, Loss: {loss.detach().item() * num_grad_accum}")
num_grad_accum = 2 # обновляем параметры каждые 2 шага
train_loop(cm_unet, pipe, train_dataloader, optimizer, ct_loss, num_grad_accum)
<ipython-input-17-04d1c509c8fc>:7: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
scaler = GradScaler()
0%| | 0/625 [00:00<?, ?it/s]
<ipython-input-17-04d1c509c8fc>:15: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with autocast(dtype=torch.float16): # Включаем mixed precision
Step 1, Loss: 0.0009964737109839916 Step 2, Loss: 0.0007123646792024374 Step 3, Loss: 0.0025129171553999186 Step 4, Loss: 0.002672075293958187 Step 5, Loss: 0.0012879869900643826 Step 6, Loss: 0.0025040223263204098 Step 7, Loss: 0.0009721919195726514 Step 8, Loss: 0.0007318910211324692 Step 9, Loss: 0.0028819323051720858 Step 10, Loss: 0.0010622225236147642 Step 11, Loss: 0.0020778956823050976 Step 12, Loss: 0.0010942246299237013 Step 13, Loss: 0.0012017062399536371 Step 14, Loss: 0.0007995384512469172 Step 15, Loss: 0.0006579371402040124 Step 16, Loss: 0.0011498258681967854 Step 17, Loss: 0.0008381439838558435 Step 18, Loss: 0.0013155571650713682 Step 19, Loss: 0.001420671702362597 Step 20, Loss: 0.0007696166867390275 Step 21, Loss: 0.0008418669458478689 Step 22, Loss: 0.0015048839850351214 Step 23, Loss: 0.0008957093814387918 Step 24, Loss: 0.0025862189941108227 Step 25, Loss: 0.001034852466545999 Step 26, Loss: 0.00190110900439322 Step 27, Loss: 0.0006745924474671483 Step 28, Loss: 0.0020768148824572563 Step 29, Loss: 0.0013370923697948456 Step 30, Loss: 0.0008536883979104459 Step 31, Loss: 0.001032006461173296 Step 32, Loss: 0.0017720642499625683 Step 33, Loss: 0.0009183366782963276 Step 34, Loss: 0.0014106683665886521 Step 35, Loss: 0.0012036757543683052 Step 36, Loss: 0.001591663807630539 Step 37, Loss: 0.0030212663114070892 Step 38, Loss: 0.0023587362375110388 Step 39, Loss: 0.0014438326470553875 Step 40, Loss: 0.0007154847262427211 Step 41, Loss: 0.001517471857368946 Step 42, Loss: 0.0009275604970753193 Step 43, Loss: 0.0016475626034662127 Step 44, Loss: 0.0015158411115407944 Step 45, Loss: 0.0010888006072491407 Step 46, Loss: 0.0008507425663992763 Step 47, Loss: 0.000835998565889895 Step 48, Loss: 0.002443537348881364 Step 49, Loss: 0.0011052710469812155 Step 50, Loss: 0.001367520191706717 Step 51, Loss: 0.0014835491310805082 Step 52, Loss: 0.0021522396709769964 Step 53, Loss: 0.001071461010724306 Step 54, Loss: 0.0008708474924787879 Step 55, Loss: 0.0016729355556890368 Step 56, Loss: 0.0012967146467417479 Step 57, Loss: 0.0024283856619149446 Step 58, Loss: 0.004348785616457462 Step 59, Loss: 0.0018030828796327114 Step 60, Loss: 0.005939250811934471 Step 61, Loss: 0.0015603761421516538 Step 62, Loss: 0.0015945896739140153 Step 63, Loss: 0.001367853139527142 Step 64, Loss: 0.0010217612143605947 Step 65, Loss: 0.0018260078504681587 Step 66, Loss: 0.0016675188671797514 Step 67, Loss: 0.002414319198578596 Step 68, Loss: 0.0022158417850732803 Step 69, Loss: 0.0013161207316443324 Step 70, Loss: 0.0009690714068710804 Step 71, Loss: 0.0037761759012937546 Step 72, Loss: 0.0013274485245347023 Step 73, Loss: 0.0017683268524706364 Step 74, Loss: 0.0022360978182405233 Step 75, Loss: 0.0018206157255917788 Step 76, Loss: 0.0019233342027291656 Step 77, Loss: 0.0019138399511575699 Step 78, Loss: 0.0018020558636635542 Step 79, Loss: 0.0013781136367470026 Step 80, Loss: 0.003604952711611986 Step 81, Loss: 0.002405259758234024 Step 82, Loss: 0.0013625816209241748 Step 83, Loss: 0.00869310274720192 Step 84, Loss: 0.00275280955247581 Step 85, Loss: 0.004032185301184654 Step 86, Loss: 0.00870254822075367 Step 87, Loss: 0.002071268390864134 Step 88, Loss: 0.0018498199060559273 Step 89, Loss: 0.00991445779800415 Step 90, Loss: 0.0030165230855345726 Step 91, Loss: 0.003333171596750617 Step 92, Loss: 0.003185434266924858 Step 93, Loss: 0.0021970977541059256 Step 94, Loss: 0.003695380873978138 Step 95, Loss: 0.013177640736103058 Step 96, Loss: 0.005383758805692196 Step 97, Loss: 0.005162822548300028 Step 98, Loss: 0.005881481803953648 Step 99, Loss: 0.005417659878730774 Step 100, Loss: 0.007452243007719517 Step 101, Loss: 0.008300170302391052 Step 102, Loss: 0.006970586255192757 Step 103, Loss: 0.00356032932177186 Step 104, Loss: 0.004906745627522469 Step 105, Loss: 0.0038089922163635492 Step 106, Loss: 0.0015337818767875433 Step 107, Loss: 0.00300682638771832 Step 108, Loss: 0.008650153875350952 Step 109, Loss: 0.004231759812682867 Step 110, Loss: 0.0029275708366185427 Step 111, Loss: 0.003925336059182882 Step 112, Loss: 0.004475311376154423 Step 113, Loss: 0.0034971064887940884 Step 114, Loss: 0.004230810329318047 Step 115, Loss: 0.0024554363917559385 Step 116, Loss: 0.0021734843030571938 Step 117, Loss: 0.008273965679109097 Step 118, Loss: 0.002536300104111433 Step 119, Loss: 0.003144968766719103 Step 120, Loss: 0.0033394754864275455 Step 121, Loss: 0.0019214244093745947 Step 122, Loss: 0.001339460490271449 Step 123, Loss: 0.003105826210230589 Step 124, Loss: 0.0019138716161251068 Step 125, Loss: 0.0009612067369744182 Step 126, Loss: 0.0028296932578086853 Step 127, Loss: 0.0021417844109237194 Step 128, Loss: 0.004075353033840656 Step 129, Loss: 0.004901350475847721 Step 130, Loss: 0.0011327839456498623 Step 131, Loss: 0.004618014208972454 Step 132, Loss: 0.002003858797252178 Step 133, Loss: 0.003282646182924509 Step 134, Loss: 0.0014006216078996658 Step 135, Loss: 0.0014191106893122196 Step 136, Loss: 0.0037141223438084126 Step 137, Loss: 0.002633387455716729 Step 138, Loss: 0.002554066013544798 Step 139, Loss: 0.0018516054842621088 Step 140, Loss: 0.0018092331010848284 Step 141, Loss: 0.0017135003581643105 Step 142, Loss: 0.0036260229535400867 Step 143, Loss: 0.0030195917934179306 Step 144, Loss: 0.0009842927101999521 Step 145, Loss: 0.0015842400025576353 Step 146, Loss: 0.0017116828821599483 Step 147, Loss: 0.0015447353944182396 Step 148, Loss: 0.001342438394203782 Step 149, Loss: 0.002777844201773405 Step 150, Loss: 0.001807103049941361 Step 151, Loss: 0.0025249419268220663 Step 152, Loss: 0.0017091265181079507 Step 153, Loss: 0.0024985510390251875 Step 154, Loss: 0.001993537414819002 Step 155, Loss: 0.0019271315541118383 Step 156, Loss: 0.0018232737202197313 Step 157, Loss: 0.002669207751750946 Step 158, Loss: 0.0013759227003902197 Step 159, Loss: 0.0024884392041713 Step 160, Loss: 0.003064761171117425 Step 161, Loss: 0.0016928670229390264 Step 162, Loss: 0.0018594666616991162 Step 163, Loss: 0.0022970836143940687 Step 164, Loss: 0.0031568421982228756 Step 165, Loss: 0.007324790116399527 Step 166, Loss: 0.001695355400443077 Step 167, Loss: 0.0017826779512688518 Step 168, Loss: 0.0011040661484003067 Step 169, Loss: 0.0010679103434085846 Step 170, Loss: 0.0011576625984162092 Step 171, Loss: 0.003895317669957876 Step 172, Loss: 0.001107738702557981 Step 173, Loss: 0.002163609955459833 Step 174, Loss: 0.0024494314566254616 Step 175, Loss: 0.001972742145881057 Step 176, Loss: 0.001785499625839293 Step 177, Loss: 0.002004774287343025 Step 178, Loss: 0.0009915400296449661 Step 179, Loss: 0.0015638095792382956 Step 180, Loss: 0.0013448246754705906 Step 181, Loss: 0.0031648697331547737 Step 182, Loss: 0.001704327529296279 Step 183, Loss: 0.002551598474383354 Step 184, Loss: 0.002767860423773527 Step 185, Loss: 0.002158526098355651 Step 186, Loss: 0.001088866381905973 Step 187, Loss: 0.0024332161992788315 Step 188, Loss: 0.001555718365125358 Step 189, Loss: 0.0014911929611116648 Step 190, Loss: 0.0012259191134944558 Step 191, Loss: 0.0019990727305412292 Step 192, Loss: 0.004777118563652039 Step 193, Loss: 0.0017879370134323835 Step 194, Loss: 0.0022225005086511374 Step 195, Loss: 0.0018355753272771835 Step 196, Loss: 0.0021953212562948465 Step 197, Loss: 0.0019467687234282494 Step 198, Loss: 0.0033462434075772762 Step 199, Loss: 0.0018740312661975622 Step 200, Loss: 0.0010305580217391253 Step 201, Loss: 0.0024460656568408012 Step 202, Loss: 0.0010713809169828892 Step 203, Loss: 0.004301274195313454 Step 204, Loss: 0.002050131093710661 Step 205, Loss: 0.0015187339158728719 Step 206, Loss: 0.0028489278629422188 Step 207, Loss: 0.0027420278638601303 Step 208, Loss: 0.002065301174297929 Step 209, Loss: 0.0020583991426974535 Step 210, Loss: 0.0030916407704353333 Step 211, Loss: 0.003561109770089388 Step 212, Loss: 0.0018161950865760446 Step 213, Loss: 0.001744080800563097 Step 214, Loss: 0.003127105999737978 Step 215, Loss: 0.0030535277910530567 Step 216, Loss: 0.0025577410124242306 Step 217, Loss: 0.002055512275546789 Step 218, Loss: 0.0013864929787814617 Step 219, Loss: 0.0029660311993211508 Step 220, Loss: 0.0013375321868807077 Step 221, Loss: 0.002898484468460083 Step 222, Loss: 0.0024262629449367523 Step 223, Loss: 0.003130171447992325 Step 224, Loss: 0.0014638151042163372 Step 225, Loss: 0.00363878789357841 Step 226, Loss: 0.003392276354134083 Step 227, Loss: 0.002066464629024267 Step 228, Loss: 0.003555938834324479 Step 229, Loss: 0.0033692405559122562 Step 230, Loss: 0.0010493635199964046 Step 231, Loss: 0.0010821252362802625 Step 232, Loss: 0.0017637109849601984 Step 233, Loss: 0.0012440495193004608 Step 234, Loss: 0.0016341344453394413 Step 235, Loss: 0.0015762588009238243 Step 236, Loss: 0.0010235632071271539 Step 237, Loss: 0.001646367833018303 Step 238, Loss: 0.002081636106595397 Step 239, Loss: 0.0012786537408828735 Step 240, Loss: 0.002571119461208582 Step 241, Loss: 0.0013509748969227076 Step 242, Loss: 0.003147643292322755 Step 243, Loss: 0.0014911983162164688 Step 244, Loss: 0.0029995788354426622 Step 245, Loss: 0.0022395257838070393 Step 246, Loss: 0.0013257043901830912 Step 247, Loss: 0.0016677497187629342 Step 248, Loss: 0.001462545245885849 Step 249, Loss: 0.003457159036770463 Step 250, Loss: 0.0013242086861282587 Step 251, Loss: 0.0021175735164433718 Step 252, Loss: 0.0023571918718516827 Step 253, Loss: 0.0030575694981962442 Step 254, Loss: 0.0024516936391592026 Step 255, Loss: 0.002214903710409999 Step 256, Loss: 0.0020048075821250677 Step 257, Loss: 0.0030402978882193565 Step 258, Loss: 0.0031706318259239197 Step 259, Loss: 0.0019083842635154724 Step 260, Loss: 0.0021437020041048527 Step 261, Loss: 0.001716109924018383 Step 262, Loss: 0.0012771545443683863 Step 263, Loss: 0.0017771539278328419 Step 264, Loss: 0.0015819550026208162 Step 265, Loss: 0.002498017158359289 Step 266, Loss: 0.0016480679623782635 Step 267, Loss: 0.0022397267166525126 Step 268, Loss: 0.002551595214754343 Step 269, Loss: 0.001356970053166151 Step 270, Loss: 0.002087733708322048 Step 271, Loss: 0.0026639692950993776 Step 272, Loss: 0.0052254656329751015 Step 273, Loss: 0.0013680905103683472 Step 274, Loss: 0.0036784247495234013 Step 275, Loss: 0.000992199988104403 Step 276, Loss: 0.0018863962031900883 Step 277, Loss: 0.0029394528828561306 Step 278, Loss: 0.001596520422026515 Step 279, Loss: 0.004238415509462357 Step 280, Loss: 0.0030783922411501408 Step 281, Loss: 0.0021919207647442818 Step 282, Loss: 0.002420907374471426 Step 283, Loss: 0.0017966690938919783 Step 284, Loss: 0.001461828825995326 Step 285, Loss: 0.0036598462611436844 Step 286, Loss: 0.0012443051673471928 Step 287, Loss: 0.00148701760917902 Step 288, Loss: 0.0017320663901045918 Step 289, Loss: 0.0014320281334221363 Step 290, Loss: 0.008277904242277145 Step 291, Loss: 0.0017746041994541883 Step 292, Loss: 0.001563466852530837 Step 293, Loss: 0.002196545246988535 Step 294, Loss: 0.003776966128498316 Step 295, Loss: 0.001591964508406818 Step 296, Loss: 0.001984924543648958 Step 297, Loss: 0.002549830824136734 Step 298, Loss: 0.0013229760807007551 Step 299, Loss: 0.0014838275965303183 Step 300, Loss: 0.0012145370710641146 Step 301, Loss: 0.002189289778470993 Step 302, Loss: 0.0022014169953763485 Step 303, Loss: 0.001216724282130599 Step 304, Loss: 0.0013273911317810416 Step 305, Loss: 0.0013471723068505526 Step 306, Loss: 0.002097218995913863 Step 307, Loss: 0.0031583341769874096 Step 308, Loss: 0.0027093582320958376 Step 309, Loss: 0.004970692563802004 Step 310, Loss: 0.0018247850239276886 Step 311, Loss: 0.00376988691277802 Step 312, Loss: 0.0016975709004327655 Step 313, Loss: 0.0009460895671509206 Step 314, Loss: 0.003457896178588271 Step 315, Loss: 0.0015740500530228019 Step 316, Loss: 0.0018885633908212185 Step 317, Loss: 0.002235714579001069 Step 318, Loss: 0.0029280800372362137 Step 319, Loss: 0.0021540713496506214 Step 320, Loss: 0.0022602220997214317 Step 321, Loss: 0.0018740276573225856 Step 322, Loss: 0.0022994857281446457 Step 323, Loss: 0.0013489238917827606 Step 324, Loss: 0.00139906897675246 Step 325, Loss: 0.0030182464979588985 Step 326, Loss: 0.0015988049563020468 Step 327, Loss: 0.0018778664525598288 Step 328, Loss: 0.001625601202249527 Step 329, Loss: 0.0018519361037760973 Step 330, Loss: 0.0017070379108190536 Step 331, Loss: 0.0021748319268226624 Step 332, Loss: 0.0017367968102917075 Step 333, Loss: 0.0014846296980977058 Step 334, Loss: 0.002475373214110732 Step 335, Loss: 0.002954889554530382 Step 336, Loss: 0.0046874405816197395 Step 337, Loss: 0.00242308946326375 Step 338, Loss: 0.001720296568237245 Step 339, Loss: 0.0010912759462371469 Step 340, Loss: 0.0018609871622174978 Step 341, Loss: 0.00340099073946476 Step 342, Loss: 0.004037333186715841 Step 343, Loss: 0.0028379643335938454 Step 344, Loss: 0.0013051190180703998 Step 345, Loss: 0.002631761133670807 Step 346, Loss: 0.0025198571383953094 Step 347, Loss: 0.0032972507178783417 Step 348, Loss: 0.003289341926574707 Step 349, Loss: 0.0017978892428800464 Step 350, Loss: 0.0033561696764081717 Step 351, Loss: 0.0021787723526358604 Step 352, Loss: 0.003375221509486437 Step 353, Loss: 0.001669022487476468 Step 354, Loss: 0.002778403926640749 Step 355, Loss: 0.0023496164940297604 Step 356, Loss: 0.0018149161478504539 Step 357, Loss: 0.0014565347228199244 Step 358, Loss: 0.0016320422291755676 Step 359, Loss: 0.002429002895951271 Step 360, Loss: 0.0029447644483298063 Step 361, Loss: 0.0025833663530647755 Step 362, Loss: 0.0026737430598586798 Step 363, Loss: 0.0021661401260644197 Step 364, Loss: 0.0017511245096102357 Step 365, Loss: 0.002903490327298641 Step 366, Loss: 0.005947783123701811 Step 367, Loss: 0.004735832568258047 Step 368, Loss: 0.0019470665138214827 Step 369, Loss: 0.004013554193079472 Step 370, Loss: 0.002077719196677208 Step 371, Loss: 0.002286328934133053 Step 372, Loss: 0.003162218490615487 Step 373, Loss: 0.003371443133801222 Step 374, Loss: 0.0022568991407752037 Step 375, Loss: 0.0013082942459732294 Step 376, Loss: 0.0021568434312939644 Step 377, Loss: 0.002213746774941683 Step 378, Loss: 0.0030593108385801315 Step 379, Loss: 0.0019463550997897983 Step 380, Loss: 0.003151160664856434 Step 381, Loss: 0.0022312854416668415 Step 382, Loss: 0.0017455383203923702 Step 383, Loss: 0.0027218933682888746 Step 384, Loss: 0.0014385171234607697 Step 385, Loss: 0.004356844816356897 Step 386, Loss: 0.0030530733056366444 Step 387, Loss: 0.0015455837128683925 Step 388, Loss: 0.0036600164603441954 Step 389, Loss: 0.003361268900334835 Step 390, Loss: 0.0047602481208741665 Step 391, Loss: 0.0034189848229289055 Step 392, Loss: 0.0022142487578094006 Step 393, Loss: 0.0018446831963956356 Step 394, Loss: 0.0021419369149953127 Step 395, Loss: 0.0016223136335611343 Step 396, Loss: 0.002930941991508007 Step 397, Loss: 0.002225317992269993 Step 398, Loss: 0.0027877832762897015 Step 399, Loss: 0.002290947362780571 Step 400, Loss: 0.0028579893987625837 Step 401, Loss: 0.0017973837675526738 Step 402, Loss: 0.0019481240306049585 Step 403, Loss: 0.0030283827800303698 Step 404, Loss: 0.0019655197393149137 Step 405, Loss: 0.0027977675199508667 Step 406, Loss: 0.0035928445868194103 Step 407, Loss: 0.001847823616117239 Step 408, Loss: 0.001033484935760498 Step 409, Loss: 0.002944400068372488 Step 410, Loss: 0.0029716547578573227 Step 411, Loss: 0.0015361449914053082 Step 412, Loss: 0.0024315500631928444 Step 413, Loss: 0.0021710386499762535 Step 414, Loss: 0.0012272449675947428 Step 415, Loss: 0.00263773906044662 Step 416, Loss: 0.0038949991576373577 Step 417, Loss: 0.0014056290965527296 Step 418, Loss: 0.001367239747196436 Step 419, Loss: 0.003612218890339136 Step 420, Loss: 0.0015791754703968763 Step 421, Loss: 0.0046881032176315784 Step 422, Loss: 0.002825289499014616 Step 423, Loss: 0.0012546565849334002 Step 424, Loss: 0.0020847772248089314 Step 425, Loss: 0.002753379987552762 Step 426, Loss: 0.00311181228607893 Step 427, Loss: 0.0018361861584708095 Step 428, Loss: 0.0020177848637104034 Step 429, Loss: 0.0017246806528419256 Step 430, Loss: 0.0022823826875537634 Step 431, Loss: 0.0038195070810616016 Step 432, Loss: 0.002087909961119294 Step 433, Loss: 0.001844338490627706 Step 434, Loss: 0.0013272862415760756 Step 435, Loss: 0.0024121454916894436 Step 436, Loss: 0.004444989841431379 Step 437, Loss: 0.001688591786660254 Step 438, Loss: 0.0016547583509236574 Step 439, Loss: 0.0028140549547970295 Step 440, Loss: 0.0022476916201412678 Step 441, Loss: 0.0026859110221266747 Step 442, Loss: 0.0029036798514425755 Step 443, Loss: 0.002507369965314865 Step 444, Loss: 0.002403222257271409 Step 445, Loss: 0.0014040693640708923 Step 446, Loss: 0.002955233445391059 Step 447, Loss: 0.0043233102187514305 Step 448, Loss: 0.004444340243935585 Step 449, Loss: 0.002760730916634202 Step 450, Loss: 0.0037729348987340927 Step 451, Loss: 0.0037752510979771614 Step 452, Loss: 0.002742704004049301 Step 453, Loss: 0.001664300449192524 Step 454, Loss: 0.003079385496675968 Step 455, Loss: 0.001461958047002554 Step 456, Loss: 0.002757724840193987 Step 457, Loss: 0.0021401511039584875 Step 458, Loss: 0.0033850609324872494 Step 459, Loss: 0.003520863363519311 Step 460, Loss: 0.003346397541463375 Step 461, Loss: 0.0019605071283876896 Step 462, Loss: 0.003859851509332657 Step 463, Loss: 0.003244594670832157 Step 464, Loss: 0.004167188890278339 Step 465, Loss: 0.0025255850050598383 Step 466, Loss: 0.003343079937621951 Step 467, Loss: 0.0036160782910883427 Step 468, Loss: 0.004407159052789211 Step 469, Loss: 0.0034948645625263453 Step 470, Loss: 0.0021185046061873436 Step 471, Loss: 0.002002623863518238 Step 472, Loss: 0.0018690098077058792 Step 473, Loss: 0.004508321639150381 Step 474, Loss: 0.0017977735260501504 Step 475, Loss: 0.0017795597668737173 Step 476, Loss: 0.0014956528320908546 Step 477, Loss: 0.001654409570619464 Step 478, Loss: 0.002867832314223051 Step 479, Loss: 0.0021172226406633854 Step 480, Loss: 0.0013408382656052709 Step 481, Loss: 0.00435650022700429 Step 482, Loss: 0.0030726974364370108 Step 483, Loss: 0.0013296615798026323 Step 484, Loss: 0.001793580362573266 Step 485, Loss: 0.0016964077949523926 Step 486, Loss: 0.001986682415008545 Step 487, Loss: 0.0020661265589296818 Step 488, Loss: 0.0026734794955700636 Step 489, Loss: 0.003145786002278328 Step 490, Loss: 0.0022016228176653385 Step 491, Loss: 0.002660672180354595 Step 492, Loss: 0.001197907142341137 Step 493, Loss: 0.003564678831025958 Step 494, Loss: 0.002485886448994279 Step 495, Loss: 0.0019238153472542763 Step 496, Loss: 0.0022778897546231747 Step 497, Loss: 0.002266446128487587 Step 498, Loss: 0.004147983156144619 Step 499, Loss: 0.004518461879342794 Step 500, Loss: 0.002260998822748661 Step 501, Loss: 0.0029638917185366154 Step 502, Loss: 0.0026786173693835735 Step 503, Loss: 0.0016313818050548434 Step 504, Loss: 0.0017759317997843027 Step 505, Loss: 0.0021710728760808706 Step 506, Loss: 0.0029801903292536736 Step 507, Loss: 0.0018787817098200321 Step 508, Loss: 0.004778842441737652 Step 509, Loss: 0.0021530394442379475 Step 510, Loss: 0.004105462692677975 Step 511, Loss: 0.003464809153228998 Step 512, Loss: 0.0026504206471145153 Step 513, Loss: 0.0022748950868844986 Step 514, Loss: 0.001675811829045415 Step 515, Loss: 0.0021095022093504667 Step 516, Loss: 0.0020678124856203794 Step 517, Loss: 0.0029012225568294525 Step 518, Loss: 0.004787777550518513 Step 519, Loss: 0.0035675661638379097 Step 520, Loss: 0.0033075539395213127 Step 521, Loss: 0.002297051018103957 Step 522, Loss: 0.0034762569703161716 Step 523, Loss: 0.0032242024317383766 Step 524, Loss: 0.002810206264257431 Step 525, Loss: 0.0016817448195070028 Step 526, Loss: 0.0018216629978269339 Step 527, Loss: 0.002537459833547473 Step 528, Loss: 0.003333230037242174 Step 529, Loss: 0.001697663450613618 Step 530, Loss: 0.0038388343527913094 Step 531, Loss: 0.002034861594438553 Step 532, Loss: 0.002954866038635373 Step 533, Loss: 0.003910453990101814 Step 534, Loss: 0.0020375121384859085 Step 535, Loss: 0.002452058019116521 Step 536, Loss: 0.0020335500594228506 Step 537, Loss: 0.0019645700231194496 Step 538, Loss: 0.0033477586694061756 Step 539, Loss: 0.004655073396861553 Step 540, Loss: 0.001522570033557713 Step 541, Loss: 0.003043032716959715 Step 542, Loss: 0.0023580617271363735 Step 543, Loss: 0.0036311703734099865 Step 544, Loss: 0.002830999670550227 Step 545, Loss: 0.0026946349535137415 Step 546, Loss: 0.0027369363233447075 Step 547, Loss: 0.0014522189740091562 Step 548, Loss: 0.0018265830585733056 Step 549, Loss: 0.0013534543104469776 Step 550, Loss: 0.001750377588905394 Step 551, Loss: 0.0018752054311335087 Step 552, Loss: 0.003588886931538582 Step 553, Loss: 0.0023917958606034517 Step 554, Loss: 0.002180017763748765 Step 555, Loss: 0.0013729555066674948 Step 556, Loss: 0.002268531359732151 Step 557, Loss: 0.001678610104136169 Step 558, Loss: 0.0031022634357213974 Step 559, Loss: 0.0020591646898537874 Step 560, Loss: 0.003965404815971851 Step 561, Loss: 0.0015510644298046827 Step 562, Loss: 0.0015438924310728908 Step 563, Loss: 0.0026267962530255318 Step 564, Loss: 0.003942787181586027 Step 565, Loss: 0.0023364638909697533 Step 566, Loss: 0.0017311549745500088 Step 567, Loss: 0.0023734201677143574 Step 568, Loss: 0.0023933923803269863 Step 569, Loss: 0.0020020711235702038 Step 570, Loss: 0.001555254915729165 Step 571, Loss: 0.0016916969325393438 Step 572, Loss: 0.001597439288161695 Step 573, Loss: 0.004049327224493027 Step 574, Loss: 0.004352931398898363 Step 575, Loss: 0.002702228259295225 Step 576, Loss: 0.004212164785712957 Step 577, Loss: 0.002407374791800976 Step 578, Loss: 0.0033760140649974346 Step 579, Loss: 0.003392706857994199 Step 580, Loss: 0.0023206849582493305 Step 581, Loss: 0.0013627472799271345 Step 582, Loss: 0.002573030535131693 Step 583, Loss: 0.0023301132023334503 Step 584, Loss: 0.00240900507196784 Step 585, Loss: 0.004537998698651791 Step 586, Loss: 0.0022711679339408875 Step 587, Loss: 0.004056943580508232 Step 588, Loss: 0.002882580505684018 Step 589, Loss: 0.002988439751788974 Step 590, Loss: 0.005434936378151178 Step 591, Loss: 0.0025192226748913527 Step 592, Loss: 0.0028609472792595625 Step 593, Loss: 0.0015962341567501426 Step 594, Loss: 0.004357056692242622 Step 595, Loss: 0.0018941131420433521 Step 596, Loss: 0.0015789249446243048 Step 597, Loss: 0.0023252214305102825 Step 598, Loss: 0.0018447035690769553 Step 599, Loss: 0.002543957205489278 Step 600, Loss: 0.002252105623483658 Step 601, Loss: 0.002090814057737589 Step 602, Loss: 0.002668071072548628 Step 603, Loss: 0.001780823222361505 Step 604, Loss: 0.00167083612177521 Step 605, Loss: 0.0043045347556471825 Step 606, Loss: 0.0017265079077333212 Step 607, Loss: 0.001871362328529358 Step 608, Loss: 0.003699228400364518 Step 609, Loss: 0.00566853117197752 Step 610, Loss: 0.002942024264484644 Step 611, Loss: 0.0027231022249907255 Step 612, Loss: 0.007950660772621632 Step 613, Loss: 0.002311816206201911 Step 614, Loss: 0.002136490074917674 Step 615, Loss: 0.0017656716518104076 Step 616, Loss: 0.001274152658879757 Step 617, Loss: 0.0034350096248090267 Step 618, Loss: 0.001877711503766477 Step 619, Loss: 0.001676612184382975 Step 620, Loss: 0.002426299499347806 Step 621, Loss: 0.002206122037023306 Step 622, Loss: 0.002395118586719036 Step 623, Loss: 0.002627358539029956 Step 624, Loss: 0.0031444155611097813 Step 625, Loss: 0.0015789918834343553
# torch.save(cm_unet.state_dict(), '/content/drive/MyDrive/cv_model/cm_model.pth')
Задание 5¶
Генерация с помощью обученной консистенси модели¶
Настало время погенерировать картинки с помощью нашей модели. Напомним, что мы не можем для консистенси моделей использовать DDIM и другие классические солверы для диффузии. Нам нужен специальный сэмплер для CM, который схематично изображен на картинке ниже:
Чуть более формально:
$x_{t_n} \sim {N}(0, I)$
$for\ t_i \in [t_n, ..., t_1]:$
$\epsilon \leftarrow unet(x_{t_i})$
$x_0 \leftarrow DDIM(\epsilon, x_{t_i}, t_i, 0)$
$x_{t_{i-1}} \leftarrow q(x_{t_{i-1}} | x_0)$
Classifier-free guidance (CFG)
Также вам надо реализовать поддержку CFG в CM сэмплирование. Вспомним формулу:
$\epsilon_w = {\color{blue}{\epsilon_{uncond}}} + w \cdot (\epsilon_{cond} - \epsilon_{uncond})$, где $w \geq 1$
Обратим внимание, что режим "без гайденса" соотвествует $w = 1$, что немного контринтуитивно, но в большинстве реализаций будет встречаться именно такой вид этой формулы.
@torch.no_grad()
def consistency_sampling(
pipe,
prompt,
num_inference_steps=4,
generator=None,
num_images_per_prompt=4,
guidance_scale=1
):
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
device = pipe._execution_device
# Извлекаем эмбеды из текстовых промптов. Реализуйте вызов pipe.encode_prompt
do_classifier_free_guidance = guidance_scale > 0
prompt_embeds = pipe.encode_prompt(
prompt, device=device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=do_classifier_free_guidance
)[0]
null_prompt_embeds = pipe.encode_prompt(
[""] * batch_size, device=device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=do_classifier_free_guidance
)[0]
assert prompt_embeds.dtype == null_prompt_embeds.dtype == torch.float16
# Настраиваем параметры scheduler-a
assert pipe.scheduler.config['timestep_spacing'] == 'trailing'
pipe.scheduler.set_timesteps(num_inference_steps)
# Создаем батч латентов из N(0,I)
latents = torch.randn(
(batch_size * num_images_per_prompt, pipe.unet.in_channels, pipe.unet.sample_size, pipe.unet.sample_size),
generator=generator,
device=device,
dtype=torch.float16,
)
for i, t in enumerate(tqdm(pipe.scheduler.timesteps)):
t = torch.tensor([t] * len(latents)).to(device)
zero_t = torch.tensor([0] * len(latents)).to(device)
cond_noise_pred = pipe.unet(
latents, t, encoder_hidden_states=prompt_embeds
).sample
if do_classifier_free_guidance:
uncond_noise_pred = pipe.unet(
latents, t, encoder_hidden_states=null_prompt_embeds
).sample
noise_pred = uncond_noise_pred + guidance_scale * (cond_noise_pred - uncond_noise_pred)
else:
noise_pred = cond_noise_pred
# Получаем x_0 оценку из x_t
x_0 = ddim_solver_step(
model_output=noise_pred, x_t=latents, t=t, s=zero_t, scheduler=pipe.scheduler
)
if i + 1 < num_inference_steps:
# Переход на следующий шаг
s = pipe.scheduler.timesteps[i+1]
s = torch.tensor([s] * len(latents)).to(device)
latents = q_sample(x=x_0, t=s, scheduler=pipe.scheduler)
else:
# Последний шаг
latents = x_0
latents = latents.half()
image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
do_denormalize = [True] * image.shape[0]
image = pipe.image_processor.postprocess(image, output_type="pil", do_denormalize=do_denormalize)
return image
Попробуем сгененировать что-то нашей моделью. Можно поиграться с разными сидами и гайденс скейлами.
Референс, что примерно должно получиться на этом этапе для guidance_scale=2. Как видите, картинки стали почетче, но пока все еще так себе.

pipe.unet = cm_unet.eval().to(torch.float16)
assert cm_unet.active_adapter == 'ct'
generator = torch.Generator(device="cuda").manual_seed(1)
guidance_scale = 3
# Заменяем генерацию пайплайном на наше сэмплирование.
images = consistency_sampling(
pipe,
prompt,
num_inference_steps=4,
generator=generator,
num_images_per_prompt=4,
guidance_scale=guidance_scale
)
visualize_images(images)
/usr/local/lib/python3.10/dist-packages/peft/tuners/lora/model.py:364: FutureWarning: Accessing config attribute `in_channels` directly via 'UNet2DConditionModel' object attribute is deprecated. Please access 'in_channels' over 'UNet2DConditionModel's config object instead, e.g. 'unet.config.in_channels'. return getattr(self.model, name)
0%| | 0/4 [00:00<?, ?it/s]
Consistency Distillation¶
Задание №6¶
Теперь давайте попробуем перейти к постановке дистилляции, где шаг из $x_t$ в $x_s$ будет делаться не аналитически, а c помощью модели учителя.
$\mathbf{x}_t = q(\mathbf{x}_t | \mathbf{x}_0)$
$\mathbf{x}_s = DDIM(\epsilon_\theta(\mathbf{x}_t, t), \mathbf{x}_t, t, s)$
Замечание: В text-to-image генерации classifier-free guidance (CFG) играет очень важную роль для получения хорошего качества с помощью диффузии. CFG меняет траектории ODE и раз нам он важен, то давайте и дистиллировать траектории с CFG.
Поэтому для получения точки $\mathbf{x}_{s}$ мы будем использовать шаг учителя с CFG. Это важное отличие от CT сеттинга - там мы не можем моделировать гайденс.
unet = unet.to(torch.float32)
unet.train()
assert unet.dtype == torch.float32
# Добавляем новые LoRA адаптеры для CD модели
cm_unet.add_adapter("cd", lora_config)
cm_unet.set_adapter("cd")
# Пересоздаем оптимизатор
optimizer = torch.optim.AdamW(cm_unet.parameters(), lr=1e-4)
@torch.no_grad()
def get_xs_from_xt_with_teacher(
x_0, x_t, t, s, # Не все эти аргументы могут быть вам нужны
scheduler,
prompt_embeds,
teacher_unet,
guidance_scale,
**kwargs
):
# Делаем предсказание учителем в кондишион случае: подаем эмбеды текста
cond_noise_pred = teacher_unet(
x_t, t, encoder_hidden_states=prompt_embeds
).sample
# Для CFG нам нужно делать предсказания в unconditional случае.
# Для T2I моделей, мы будем это моделировать предсказаниями для пустого промпта ""
# Извлечем эмбеды из пустого промпта и размножить их до размера батча
uncond_input_ids = pipe.tokenizer(
[""], return_tensors="pt", padding="max_length", max_length=77
).input_ids.to("cuda")
uncond_prompt_embeds = pipe.text_encoder(uncond_input_ids)[0].expand(
*prompt_embeds.shape
)
# Затем прогоняем модель для пустых промптов
uncond_noise_pred = teacher_unet(
x_t, t, encoder_hidden_states=uncond_prompt_embeds
).sample
# Применяем CFG формулу и получаем итоговый предикт учителя
noise_pred = uncond_noise_pred + guidance_scale * (cond_noise_pred - uncond_noise_pred)
# Получаем x_s из x_t
x_s = ddim_solver_step(
model_output=noise_pred, x_t=x_t, t=t, s=s, scheduler=scheduler
)
return x_s
# Сразу зададим внутрь модель учителя и guidance_scale
get_xs_from_xt_with_teacher = functools.partial(
get_xs_from_xt_with_teacher,
teacher_unet=teacher_unet,
guidance_scale=7.5
)
Еще, как показано в работе Improved Techniques for Training Consistency Models. L2 лосс не самый оптимальный выбор для консистенси моделей. Давайте в CD обучении также заменим MSE лосс на pseudo-huber лосс из статьи.
def pseudo_huber_loss(
x: torch.Tensor,
y: torch.Tensor,
c=0.001
):
diff = x - y
loss = torch.sum(torch.sqrt(diff**2 + c**2) - c)
return loss
cd_loss = functools.partial(
cm_loss_template,
loss_fn=pseudo_huber_loss,
get_boundary_timesteps=get_zero_boundary_timesteps,
get_xs_from_xt=get_xs_from_xt_with_teacher
)
assert cm_unet.active_adapter == 'cd'
Теперь обучим модель в CD режиме
Лосс большой поскольку не добавил в huber loss усреднение, но картинки получаются все равно хорошего качества!!!¶
num_grad_accum = 2 # обновляем параметры каждые 2 шага
train_loop(cm_unet, pipe, train_dataloader, optimizer, cd_loss, num_grad_accum)
<ipython-input-17-04d1c509c8fc>:7: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
scaler = GradScaler()
0%| | 0/625 [00:00<?, ?it/s]
<ipython-input-17-04d1c509c8fc>:15: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with autocast(dtype=torch.float16): # Включаем mixed precision
Step 1, Loss: 3067.6396484375 Step 2, Loss: 4501.21533203125 Step 3, Loss: 3528.36181640625 Step 4, Loss: 2838.5224609375 Step 5, Loss: 3635.279296875 Step 6, Loss: 3245.017822265625 Step 7, Loss: 4237.431640625 Step 8, Loss: 3144.990966796875 Step 9, Loss: 4142.67529296875 Step 10, Loss: 3631.2841796875 Step 11, Loss: 3534.72998046875 Step 12, Loss: 1901.096923828125 Step 13, Loss: 2335.93798828125 Step 14, Loss: 2924.720703125 Step 15, Loss: 3095.15673828125 Step 16, Loss: 2338.27099609375 Step 17, Loss: 2354.35693359375 Step 18, Loss: 2751.059814453125 Step 19, Loss: 2736.702392578125 Step 20, Loss: 2632.9658203125 Step 21, Loss: 3111.861328125 Step 22, Loss: 5292.263671875 Step 23, Loss: 2556.02294921875 Step 24, Loss: 2551.59814453125 Step 25, Loss: 3263.38623046875 Step 26, Loss: 2465.36083984375 Step 27, Loss: 2000.357666015625 Step 28, Loss: 2800.03466796875 Step 29, Loss: 2182.60546875 Step 30, Loss: 5640.53857421875 Step 31, Loss: 4463.96875 Step 32, Loss: 5104.5107421875 Step 33, Loss: 2933.88916015625 Step 34, Loss: 3778.8603515625 Step 35, Loss: 4359.3681640625 Step 36, Loss: 3861.1748046875 Step 37, Loss: 3387.4375 Step 38, Loss: 3801.18701171875 Step 39, Loss: 3751.1162109375 Step 40, Loss: 3574.050048828125 Step 41, Loss: 2936.130859375 Step 42, Loss: 2676.213623046875 Step 43, Loss: 3132.67138671875 Step 44, Loss: 4475.865234375 Step 45, Loss: 3380.264892578125 Step 46, Loss: 4189.265625 Step 47, Loss: 3740.886962890625 Step 48, Loss: 2627.85498046875 Step 49, Loss: 2620.080322265625 Step 50, Loss: 2757.20849609375 Step 51, Loss: 2848.977783203125 Step 52, Loss: 3051.163330078125 Step 53, Loss: 2933.139892578125 Step 54, Loss: 3533.27294921875 Step 55, Loss: 4150.51806640625 Step 56, Loss: 2986.15283203125 Step 57, Loss: 3702.853515625 Step 58, Loss: 3069.2958984375 Step 59, Loss: 3301.21435546875 Step 60, Loss: 4020.974853515625 Step 61, Loss: 3972.615966796875 Step 62, Loss: 3032.62890625 Step 63, Loss: 3643.7158203125 Step 64, Loss: 4450.271484375 Step 65, Loss: 3761.891357421875 Step 66, Loss: 3120.98046875 Step 67, Loss: 4099.08447265625 Step 68, Loss: 3424.960693359375 Step 69, Loss: 3724.317138671875 Step 70, Loss: 5375.58984375 Step 71, Loss: 5073.2177734375 Step 72, Loss: 3852.9150390625 Step 73, Loss: 2907.05078125 Step 74, Loss: 3241.89794921875 Step 75, Loss: 5088.7919921875 Step 76, Loss: 3307.97802734375 Step 77, Loss: 3389.269287109375 Step 78, Loss: 4032.13720703125 Step 79, Loss: 4095.20703125 Step 80, Loss: 2576.86669921875 Step 81, Loss: 6043.5419921875 Step 82, Loss: 5501.4296875 Step 83, Loss: 6051.47021484375 Step 84, Loss: 4081.576171875 Step 85, Loss: 3376.9384765625 Step 86, Loss: 3990.83251953125 Step 87, Loss: 3614.535888671875 Step 88, Loss: 6747.9072265625 Step 89, Loss: 5914.2421875 Step 90, Loss: 4253.474609375 Step 91, Loss: 6455.4697265625 Step 92, Loss: 4305.9140625 Step 93, Loss: 8737.1953125 Step 94, Loss: 6071.18212890625 Step 95, Loss: 5942.138671875 Step 96, Loss: 4217.1015625 Step 97, Loss: 2322.83642578125 Step 98, Loss: 3829.908935546875 Step 99, Loss: 3482.105224609375 Step 100, Loss: 5649.88330078125 Step 101, Loss: 9304.447265625 Step 102, Loss: 5281.0556640625 Step 103, Loss: 5694.24951171875 Step 104, Loss: 7078.431640625 Step 105, Loss: 4312.2626953125 Step 106, Loss: 3901.015869140625 Step 107, Loss: 4590.02978515625 Step 108, Loss: 7480.7421875 Step 109, Loss: 6370.4169921875 Step 110, Loss: 7317.55419921875 Step 111, Loss: 4277.236328125 Step 112, Loss: 7304.1455078125 Step 113, Loss: 12386.59765625 Step 114, Loss: 5439.9638671875 Step 115, Loss: 4104.0244140625 Step 116, Loss: 3760.8095703125 Step 117, Loss: 6188.7431640625 Step 118, Loss: 10398.21875 Step 119, Loss: 6285.6279296875 Step 120, Loss: 6792.962890625 Step 121, Loss: 7379.0625 Step 122, Loss: 8344.27734375 Step 123, Loss: 7244.67529296875 Step 124, Loss: 12012.86328125 Step 125, Loss: 7564.287109375 Step 126, Loss: 3739.274658203125 Step 127, Loss: 4759.7197265625 Step 128, Loss: 4001.96142578125 Step 129, Loss: 4556.54443359375 Step 130, Loss: 9362.6318359375 Step 131, Loss: 6356.1318359375 Step 132, Loss: 4285.30712890625 Step 133, Loss: 7411.22314453125 Step 134, Loss: 10271.453125 Step 135, Loss: 5382.138671875 Step 136, Loss: 8593.34765625 Step 137, Loss: 7603.85986328125 Step 138, Loss: 11660.451171875 Step 139, Loss: 6245.08642578125 Step 140, Loss: 5117.90087890625 Step 141, Loss: 6229.255859375 Step 142, Loss: 7945.4853515625 Step 143, Loss: 6740.63916015625 Step 144, Loss: 5087.9970703125 Step 145, Loss: 9219.767578125 Step 146, Loss: 6910.55419921875 Step 147, Loss: 8740.2607421875 Step 148, Loss: 4533.8310546875 Step 149, Loss: 7340.5283203125 Step 150, Loss: 6660.20458984375 Step 151, Loss: 5197.2548828125 Step 152, Loss: 7767.736328125 Step 153, Loss: 5133.13232421875 Step 154, Loss: 7437.169921875 Step 155, Loss: 7853.3349609375 Step 156, Loss: 9258.6396484375 Step 157, Loss: 7418.76318359375 Step 158, Loss: 6626.232421875 Step 159, Loss: 3714.283447265625 Step 160, Loss: 5961.29248046875 Step 161, Loss: 5662.3291015625 Step 162, Loss: 6956.33642578125 Step 163, Loss: 8190.3017578125 Step 164, Loss: 8912.3173828125 Step 165, Loss: 5106.20068359375 Step 166, Loss: 5819.31298828125 Step 167, Loss: 7879.0009765625 Step 168, Loss: 3761.307861328125 Step 169, Loss: 5512.0107421875 Step 170, Loss: 9456.24609375 Step 171, Loss: 9358.2021484375 Step 172, Loss: 6571.73291015625 Step 173, Loss: 5759.2841796875 Step 174, Loss: 4750.193359375 Step 175, Loss: 3060.482666015625 Step 176, Loss: 3645.848388671875 Step 177, Loss: 5293.1494140625 Step 178, Loss: 6402.9169921875 Step 179, Loss: 6034.236328125 Step 180, Loss: 6270.8154296875 Step 181, Loss: 4819.3349609375 Step 182, Loss: 8022.818359375 Step 183, Loss: 13358.484375 Step 184, Loss: 7847.2744140625 Step 185, Loss: 8682.337890625 Step 186, Loss: 5131.509765625 Step 187, Loss: 4247.0625 Step 188, Loss: 9087.06640625 Step 189, Loss: 8091.3076171875 Step 190, Loss: 4227.57958984375 Step 191, Loss: 6432.40087890625 Step 192, Loss: 5576.962890625 Step 193, Loss: 1886.610595703125 Step 194, Loss: 9776.810546875 Step 195, Loss: 5033.0546875 Step 196, Loss: 9555.078125 Step 197, Loss: 6725.07958984375 Step 198, Loss: 2349.116455078125 Step 199, Loss: 6824.6591796875 Step 200, Loss: 3863.759765625 Step 201, Loss: 5391.271484375 Step 202, Loss: 3628.31787109375 Step 203, Loss: 4978.67431640625 Step 204, Loss: 6335.658203125 Step 205, Loss: 4753.02001953125 Step 206, Loss: 4683.0947265625 Step 207, Loss: 2700.89111328125 Step 208, Loss: 8061.462890625 Step 209, Loss: 8857.8671875 Step 210, Loss: 8114.0439453125 Step 211, Loss: 5542.8154296875 Step 212, Loss: 6195.6015625 Step 213, Loss: 7623.865234375 Step 214, Loss: 6993.47509765625 Step 215, Loss: 5519.9560546875 Step 216, Loss: 6307.55908203125 Step 217, Loss: 7700.26806640625 Step 218, Loss: 5104.650390625 Step 219, Loss: 7564.41845703125 Step 220, Loss: 8365.9931640625 Step 221, Loss: 3689.85302734375 Step 222, Loss: 7112.26416015625 Step 223, Loss: 4709.306640625 Step 224, Loss: 7390.650390625 Step 225, Loss: 5056.04931640625 Step 226, Loss: 8604.79296875 Step 227, Loss: 6731.34716796875 Step 228, Loss: 6824.1826171875 Step 229, Loss: 6225.33837890625 Step 230, Loss: 4657.8828125 Step 231, Loss: 5898.32373046875 Step 232, Loss: 8481.3466796875 Step 233, Loss: 6476.8310546875 Step 234, Loss: 11240.8056640625 Step 235, Loss: 7373.90771484375 Step 236, Loss: 5295.42431640625 Step 237, Loss: 4240.84375 Step 238, Loss: 8358.541015625 Step 239, Loss: 4117.34228515625 Step 240, Loss: 8481.828125 Step 241, Loss: 10369.083984375 Step 242, Loss: 4568.6171875 Step 243, Loss: 4461.9072265625 Step 244, Loss: 7004.130859375 Step 245, Loss: 5408.955078125 Step 246, Loss: 8705.5478515625 Step 247, Loss: 6420.23388671875 Step 248, Loss: 5777.34716796875 Step 249, Loss: 5146.31494140625 Step 250, Loss: 3864.042724609375 Step 251, Loss: 6520.03759765625 Step 252, Loss: 6385.6064453125 Step 253, Loss: 5411.78369140625 Step 254, Loss: 7851.2548828125 Step 255, Loss: 6904.91552734375 Step 256, Loss: 2874.87451171875 Step 257, Loss: 5579.52587890625 Step 258, Loss: 5122.4072265625 Step 259, Loss: 10255.5693359375 Step 260, Loss: 2960.11669921875 Step 261, Loss: 6385.0810546875 Step 262, Loss: 2329.3837890625 Step 263, Loss: 8673.728515625 Step 264, Loss: 5706.6884765625 Step 265, Loss: 6505.4814453125 Step 266, Loss: 10179.833984375 Step 267, Loss: 1975.400634765625 Step 268, Loss: 13108.318359375 Step 269, Loss: 6965.330078125 Step 270, Loss: 5480.1484375 Step 271, Loss: 4018.4619140625 Step 272, Loss: 4464.1796875 Step 273, Loss: 8933.349609375 Step 274, Loss: 8485.5166015625 Step 275, Loss: 5496.9736328125 Step 276, Loss: 7384.4560546875 Step 277, Loss: 4863.380859375 Step 278, Loss: 7213.10009765625 Step 279, Loss: 5356.1591796875 Step 280, Loss: 7484.4326171875 Step 281, Loss: 6442.62255859375 Step 282, Loss: 8255.208984375 Step 283, Loss: 6983.66064453125 Step 284, Loss: 4387.44921875 Step 285, Loss: 7950.56689453125 Step 286, Loss: 4250.01953125 Step 287, Loss: 6698.52392578125 Step 288, Loss: 4678.005859375 Step 289, Loss: 5268.7900390625 Step 290, Loss: 7555.23291015625 Step 291, Loss: 7431.2294921875 Step 292, Loss: 1926.84765625 Step 293, Loss: 7947.033203125 Step 294, Loss: 7582.3505859375 Step 295, Loss: 4912.58935546875 Step 296, Loss: 3210.071044921875 Step 297, Loss: 4715.45947265625 Step 298, Loss: 11928.0234375 Step 299, Loss: 2795.9111328125 Step 300, Loss: 2382.36865234375 Step 301, Loss: 3922.52587890625 Step 302, Loss: 6586.05859375 Step 303, Loss: 4927.994140625 Step 304, Loss: 4698.634765625 Step 305, Loss: 5751.0263671875 Step 306, Loss: 4140.193359375 Step 307, Loss: 5998.912109375 Step 308, Loss: 8095.69873046875 Step 309, Loss: 4656.4951171875 Step 310, Loss: 5908.8486328125 Step 311, Loss: 5935.57275390625 Step 312, Loss: 6386.28125 Step 313, Loss: 6940.20068359375 Step 314, Loss: 9142.6171875 Step 315, Loss: 2590.2939453125 Step 316, Loss: 6724.3203125 Step 317, Loss: 3780.9267578125 Step 318, Loss: 5891.82373046875 Step 319, Loss: 5028.5029296875 Step 320, Loss: 5054.27783203125 Step 321, Loss: 9260.90625 Step 322, Loss: 4988.005859375 Step 323, Loss: 4460.4697265625 Step 324, Loss: 5937.76708984375 Step 325, Loss: 5603.0947265625 Step 326, Loss: 5106.73291015625 Step 327, Loss: 6049.591796875 Step 328, Loss: 7533.95947265625 Step 329, Loss: 3525.634765625 Step 330, Loss: 5791.76416015625 Step 331, Loss: 8983.359375 Step 332, Loss: 4312.0771484375 Step 333, Loss: 10512.8359375 Step 334, Loss: 4606.330078125 Step 335, Loss: 3263.814453125 Step 336, Loss: 4092.023681640625 Step 337, Loss: 6944.146484375 Step 338, Loss: 4410.89111328125 Step 339, Loss: 7504.6416015625 Step 340, Loss: 3232.72412109375 Step 341, Loss: 5564.2734375 Step 342, Loss: 8555.3056640625 Step 343, Loss: 6968.38037109375 Step 344, Loss: 9392.2421875 Step 345, Loss: 4656.8740234375 Step 346, Loss: 5977.08203125 Step 347, Loss: 3161.54931640625 Step 348, Loss: 8505.0556640625 Step 349, Loss: 4393.68896484375 Step 350, Loss: 6235.12353515625 Step 351, Loss: 5617.68798828125 Step 352, Loss: 5996.26416015625 Step 353, Loss: 3228.28857421875 Step 354, Loss: 3934.18798828125 Step 355, Loss: 6701.453125 Step 356, Loss: 3992.371826171875 Step 357, Loss: 4660.65625 Step 358, Loss: 4088.6162109375 Step 359, Loss: 7012.5205078125 Step 360, Loss: 5136.60400390625 Step 361, Loss: 5305.50244140625 Step 362, Loss: 5424.4150390625 Step 363, Loss: 7307.763671875 Step 364, Loss: 7429.21875 Step 365, Loss: 6045.5986328125 Step 366, Loss: 4839.5751953125 Step 367, Loss: 4386.4248046875 Step 368, Loss: 3612.373046875 Step 369, Loss: 4420.654296875 Step 370, Loss: 1770.496337890625 Step 371, Loss: 4422.84326171875 Step 372, Loss: 4809.2705078125 Step 373, Loss: 4419.31494140625 Step 374, Loss: 6618.95068359375 Step 375, Loss: 5547.0146484375 Step 376, Loss: 6225.068359375 Step 377, Loss: 6375.6171875 Step 378, Loss: 2733.548583984375 Step 379, Loss: 5085.37646484375 Step 380, Loss: 1482.4791259765625 Step 381, Loss: 5116.40625 Step 382, Loss: 7516.4189453125 Step 383, Loss: 7153.5146484375 Step 384, Loss: 5548.19921875 Step 385, Loss: 6198.091796875 Step 386, Loss: 2885.35498046875 Step 387, Loss: 9088.642578125 Step 388, Loss: 10132.908203125 Step 389, Loss: 6325.3017578125 Step 390, Loss: 6302.7333984375 Step 391, Loss: 9528.0703125 Step 392, Loss: 4928.8505859375 Step 393, Loss: 6787.09619140625 Step 394, Loss: 5132.1337890625 Step 395, Loss: 7338.17529296875 Step 396, Loss: 4688.7548828125 Step 397, Loss: 5553.365234375 Step 398, Loss: 5271.1220703125 Step 399, Loss: 7157.4794921875 Step 400, Loss: 6882.8994140625 Step 401, Loss: 3169.628173828125 Step 402, Loss: 5126.3876953125 Step 403, Loss: 2943.015625 Step 404, Loss: 6440.2275390625 Step 405, Loss: 6651.93359375 Step 406, Loss: 4473.77001953125 Step 407, Loss: 7756.29296875 Step 408, Loss: 5183.6806640625 Step 409, Loss: 6820.4404296875 Step 410, Loss: 5539.83642578125 Step 411, Loss: 3924.71337890625 Step 412, Loss: 8651.986328125 Step 413, Loss: 6247.2685546875 Step 414, Loss: 8581.720703125 Step 415, Loss: 4809.62744140625 Step 416, Loss: 9160.2060546875 Step 417, Loss: 7738.10986328125 Step 418, Loss: 6360.07958984375 Step 419, Loss: 6054.931640625 Step 420, Loss: 7206.0810546875 Step 421, Loss: 5747.8359375 Step 422, Loss: 4190.154296875 Step 423, Loss: 8808.7548828125 Step 424, Loss: 10492.1015625 Step 425, Loss: 5053.8681640625 Step 426, Loss: 7452.0703125 Step 427, Loss: 4237.7578125 Step 428, Loss: 2869.2587890625 Step 429, Loss: 4189.037109375 Step 430, Loss: 5199.412109375 Step 431, Loss: 6138.0419921875 Step 432, Loss: 5768.3564453125 Step 433, Loss: 4534.21533203125 Step 434, Loss: 9090.6171875 Step 435, Loss: 6423.3310546875 Step 436, Loss: 7451.03369140625 Step 437, Loss: 4777.95654296875 Step 438, Loss: 6288.56494140625 Step 439, Loss: 7394.0986328125 Step 440, Loss: 4003.29052734375 Step 441, Loss: 6014.408203125 Step 442, Loss: 5068.107421875 Step 443, Loss: 5231.33984375 Step 444, Loss: 4120.45263671875 Step 445, Loss: 3309.422607421875 Step 446, Loss: 6272.455078125 Step 447, Loss: 7732.92919921875 Step 448, Loss: 4990.6318359375 Step 449, Loss: 3847.4287109375 Step 450, Loss: 9725.22265625 Step 451, Loss: 4568.7822265625 Step 452, Loss: 6251.55078125 Step 453, Loss: 4916.6533203125 Step 454, Loss: 6205.8486328125 Step 455, Loss: 3742.043701171875 Step 456, Loss: 6726.53515625 Step 457, Loss: 7195.6337890625 Step 458, Loss: 3889.393798828125 Step 459, Loss: 6712.830078125 Step 460, Loss: 6117.98828125 Step 461, Loss: 4891.23681640625 Step 462, Loss: 4433.759765625 Step 463, Loss: 7305.0263671875 Step 464, Loss: 5478.650390625 Step 465, Loss: 5166.064453125 Step 466, Loss: 3683.619873046875 Step 467, Loss: 5027.83056640625 Step 468, Loss: 4194.958984375 Step 469, Loss: 5198.7314453125 Step 470, Loss: 2188.5341796875 Step 471, Loss: 4224.28759765625 Step 472, Loss: 6546.12890625 Step 473, Loss: 8023.783203125 Step 474, Loss: 5774.46875 Step 475, Loss: 11540.9658203125 Step 476, Loss: 4529.5849609375 Step 477, Loss: 7224.9892578125 Step 478, Loss: 7143.7255859375 Step 479, Loss: 4003.281005859375 Step 480, Loss: 4674.6953125 Step 481, Loss: 4016.17138671875 Step 482, Loss: 5501.24169921875 Step 483, Loss: 6691.3173828125 Step 484, Loss: 5257.6669921875 Step 485, Loss: 4468.6845703125 Step 486, Loss: 6899.38671875 Step 487, Loss: 7834.3291015625 Step 488, Loss: 2152.55078125 Step 489, Loss: 3332.0595703125 Step 490, Loss: 5432.7578125 Step 491, Loss: 5623.12939453125 Step 492, Loss: 4124.57568359375 Step 493, Loss: 2991.28955078125 Step 494, Loss: 2484.10107421875 Step 495, Loss: 3913.890869140625 Step 496, Loss: 3984.5537109375 Step 497, Loss: 6917.552734375 Step 498, Loss: 5427.06201171875 Step 499, Loss: 6241.8974609375 Step 500, Loss: 6371.95849609375 Step 501, Loss: 6826.9833984375 Step 502, Loss: 4729.8173828125 Step 503, Loss: 8890.791015625 Step 504, Loss: 4798.78515625 Step 505, Loss: 7523.17578125 Step 506, Loss: 6336.71826171875 Step 507, Loss: 5609.0830078125 Step 508, Loss: 7567.03125 Step 509, Loss: 5495.232421875 Step 510, Loss: 8643.080078125 Step 511, Loss: 8627.75 Step 512, Loss: 4278.63134765625 Step 513, Loss: 2802.3037109375 Step 514, Loss: 5261.93017578125 Step 515, Loss: 7480.02734375 Step 516, Loss: 4462.88671875 Step 517, Loss: 4688.5537109375 Step 518, Loss: 5175.33984375 Step 519, Loss: 5336.74609375 Step 520, Loss: 4055.109375 Step 521, Loss: 8050.234375 Step 522, Loss: 4862.365234375 Step 523, Loss: 2215.04345703125 Step 524, Loss: 6717.880859375 Step 525, Loss: 7073.34619140625 Step 526, Loss: 5147.61962890625 Step 527, Loss: 5916.01513671875 Step 528, Loss: 9250.62109375 Step 529, Loss: 4330.779296875 Step 530, Loss: 6731.103515625 Step 531, Loss: 5912.619140625 Step 532, Loss: 6973.14501953125 Step 533, Loss: 5659.2822265625 Step 534, Loss: 6368.8603515625 Step 535, Loss: 2829.69677734375 Step 536, Loss: 4062.171875 Step 537, Loss: 8312.2265625 Step 538, Loss: 7402.71923828125 Step 539, Loss: 3628.8916015625 Step 540, Loss: 8412.5283203125 Step 541, Loss: 8000.80126953125 Step 542, Loss: 7564.3359375 Step 543, Loss: 8388.765625 Step 544, Loss: 6130.25390625 Step 545, Loss: 2406.29541015625 Step 546, Loss: 6641.8291015625 Step 547, Loss: 11002.904296875 Step 548, Loss: 7771.71240234375 Step 549, Loss: 5968.1435546875 Step 550, Loss: 9708.85546875 Step 551, Loss: 9129.2890625 Step 552, Loss: 5903.83251953125 Step 553, Loss: 5543.8310546875 Step 554, Loss: 6419.64892578125 Step 555, Loss: 2232.00830078125 Step 556, Loss: 5184.1640625 Step 557, Loss: 5963.85693359375 Step 558, Loss: 4534.2802734375 Step 559, Loss: 5421.09228515625 Step 560, Loss: 3904.07958984375 Step 561, Loss: 6978.623046875 Step 562, Loss: 1461.27099609375 Step 563, Loss: 12069.4833984375 Step 564, Loss: 7750.2607421875 Step 565, Loss: 7577.3076171875 Step 566, Loss: 3827.16259765625 Step 567, Loss: 5274.05322265625 Step 568, Loss: 7263.1240234375 Step 569, Loss: 6747.3330078125 Step 570, Loss: 4211.787109375 Step 571, Loss: 4571.31787109375 Step 572, Loss: 5386.0546875 Step 573, Loss: 7221.90087890625 Step 574, Loss: 3458.982177734375 Step 575, Loss: 7115.873046875 Step 576, Loss: 8545.783203125 Step 577, Loss: 6167.5146484375 Step 578, Loss: 5334.92333984375 Step 579, Loss: 4170.1953125 Step 580, Loss: 4024.8974609375 Step 581, Loss: 4247.57666015625 Step 582, Loss: 5386.41552734375 Step 583, Loss: 3921.037109375 Step 584, Loss: 5623.51171875 Step 585, Loss: 7213.2265625 Step 586, Loss: 5052.44921875 Step 587, Loss: 4799.7138671875 Step 588, Loss: 9004.5625 Step 589, Loss: 4952.087890625 Step 590, Loss: 6750.4462890625 Step 591, Loss: 5455.04931640625 Step 592, Loss: 7526.740234375 Step 593, Loss: 2873.11083984375 Step 594, Loss: 2919.0390625 Step 595, Loss: 4697.41015625 Step 596, Loss: 4942.58935546875 Step 597, Loss: 3529.548828125 Step 598, Loss: 5228.578125 Step 599, Loss: 5086.94970703125 Step 600, Loss: 6321.24951171875 Step 601, Loss: 2977.922119140625 Step 602, Loss: 4450.2197265625 Step 603, Loss: 4080.992431640625 Step 604, Loss: 4699.5341796875 Step 605, Loss: 2120.114990234375 Step 606, Loss: 4223.296875 Step 607, Loss: 4961.94775390625 Step 608, Loss: 7770.8603515625 Step 609, Loss: 9213.3515625 Step 610, Loss: 4143.97509765625 Step 611, Loss: 3693.976806640625 Step 612, Loss: 4183.45361328125 Step 613, Loss: 6152.4052734375 Step 614, Loss: 6194.6533203125 Step 615, Loss: 6285.8759765625 Step 616, Loss: 5739.7294921875 Step 617, Loss: 7159.712890625 Step 618, Loss: 3027.54638671875 Step 619, Loss: 6115.3916015625 Step 620, Loss: 3153.59716796875 Step 621, Loss: 5165.57763671875 Step 622, Loss: 4092.1318359375 Step 623, Loss: 6564.84033203125 Step 624, Loss: 8872.041015625 Step 625, Loss: 4565.2060546875
# torch.save(cm_unet.state_dict(), '/content/drive/MyDrive/cv_model/cd_model.pth')
Снова сэмплируем¶
Обратим внимание, что тут мы сэмпилруем без гайденса, потому что мы его уже частично прокинули в модель, когда делали шаг учителя с CFG.
Снова для референса приводим картинки на этом этапе:

Ваши картинки не обязаны совпадать: у вас могут быть немного менее/более качественные. Небольшая разница по качеству на оценку не влиет.
# Подставляем нашу новую обученную модель в пайплайн
pipe.unet = cm_unet.eval().to(torch.float16)
assert cm_unet.active_adapter == 'cd'
generator = torch.Generator(device="cuda").manual_seed(0)
guidance_scale = 0
images = consistency_sampling(
pipe,
prompt,
num_inference_steps=4,
generator=generator,
num_images_per_prompt=4,
guidance_scale=guidance_scale
)
visualize_images(images)
0%| | 0/4 [00:00<?, ?it/s]
Давайте посмотрим на картинки для других промптов¶
validation_prompts = [
"A sad puppy with large eyes",
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
"A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
"A girl with pale blue hair and a cami tank top",
"A lighthouse in a giant wave, origami style",
"belle epoque, christmas, red house in the forest, photo realistic, 8k",
"A small cactus with a happy face in the Sahara desert",
"Green commercial building with refrigerator and refrigeration units outside",
]
for prompt in validation_prompts:
generator = torch.Generator(device="cuda").manual_seed(0)
images = consistency_sampling(
pipe,
prompt,
num_inference_steps=4,
generator=generator,
num_images_per_prompt=4,
guidance_scale=guidance_scale
)
visualize_images(images)
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
Multi-boundary Сonsistency Distillation¶
В конце мы рассмотрим недавнюю модификацию CD, Multi-boundary CD, где интегрируем не всю траекторию сразу и потом сэмплируем с возвращением назад, а разбиваем траектории на $K$ отрезков и применяет CD внутри каждого отрезка независимо. Например, на картинке выше у нас два отрезка: зеленым и красным выделены две граничные точки. Для классического CD, рассмотренного ранее, у нас только одна граничная точка в $t = 0$
Обратим внимание, что сэмплирование становится детерминистичным и можно снова использовать DDIM солвер, где число шагов равно числу интервалов $K$, на которые мы разбили траектории во время обучения.
Этот метод гораздо лучше работает чем обычный CD, потому что решать задачу CD на отрезках, а не на всей траектории, гораздо проще. В текущем задании мы разобьем траекторию на $K=4$ отрезка.
Подробнее почитать можно в этой статье.
Задание №7 (0.5 балла, сдается в контесте)¶
Ниже реализуйте функцию, которая для $K=4$ отрезков будет сопоставлять таймстепам соответствующие граничные точки.
Например, для $K=2$ отрезков граничные точки будут: [0, 499]
$0 \leq t < 499$ -> граничная точка - $0$
$499 \leq t < 999$ -> граничная точка - $499$
Ресурсы в колабе закончились, обучал последнюю модель в kaggle¶
# Восстанавливаем модель UNet и оборачиваем её PEFT
unet = UNet2DConditionModel.from_pretrained(
'sd-legacy/stable-diffusion-v1-5',
subfolder='unet',
torch_dtype=torch.float32,
).to('cuda')
cm_unet = get_peft_model(unet, lora_config, adapter_name="multi-cd")
cm_unet.enable_gradient_checkpointing()
# Создаем новый адаптер
unet = unet.to(torch.float32)
unet.train()
assert unet.dtype == torch.float32
# Добавляем новый адаптер "multi-cd"
cm_unet.add_adapter("multi-cd", lora_config)
cm_unet.set_adapter("multi-cd")
# Пересоздаём оптимизатор
optimizer = torch.optim.AdamW(cm_unet.parameters(), lr=1e-4)
def get_multi_boundary_timesteps(
timesteps,
num_boundaries=4,
num_timesteps=1000,
):
"""
Для батча таймстепов определяем соответствующие граничные точки.
params:
timesteps: torch.Tensor(batch_size, device='cuda')
returns:
boundary_timesteps: torch.Tensor(batch_size, device='cuda')
"""
# Здесь важно аккуратно поработать с таймстепами,
# чтобы не перелетать граничные точки и при этом иногда попадать в них.
# Совет: повыводить timesteps и boundary_timesteps перед обучением.
boundaries = torch.linspace(0, num_timesteps - 1, num_boundaries + 1, device=timesteps.device, dtype=torch.long)
indices = torch.bucketize(timesteps, boundaries, right=False)
boundary_timesteps = boundaries[torch.clamp(indices - 1, min=0)]
return boundary_timesteps
multi_cd_loss = functools.partial(
cm_loss_template,
loss_fn=pseudo_huber_loss,
get_boundary_timesteps=get_multi_boundary_timesteps,
get_xs_from_xt=get_xs_from_xt_with_teacher
)
assert cm_unet.active_adapter == 'multi-cd'
Теперь обучим Multi-boundary CD модель
torch.cuda.empty_cache()
num_grad_accum = 2 # обновляем параметры каждые 2 шага
train_loop(cm_unet, pipe, train_dataloader, optimizer, multi_cd_loss, num_grad_accum)
/tmp/ipykernel_23/4028741595.py:7: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
scaler = GradScaler()
0%| | 0/625 [00:00<?, ?it/s]
/tmp/ipykernel_23/4028741595.py:15: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with autocast(dtype=torch.float16): # Включаем mixed precision
/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
Step 1, Loss: 1412.8082275390625 Step 2, Loss: 941.15087890625 Step 3, Loss: 1324.7342529296875 Step 4, Loss: 1476.2618408203125 Step 5, Loss: 1478.176513671875 Step 6, Loss: 1110.04052734375 Step 7, Loss: 1325.424560546875 Step 8, Loss: 1190.1014404296875 Step 9, Loss: 1061.3260498046875 Step 10, Loss: 1054.9925537109375 Step 11, Loss: 1480.53125 Step 12, Loss: 1221.750732421875 Step 13, Loss: 1454.07177734375 Step 14, Loss: 1729.9173583984375 Step 15, Loss: 1381.732666015625 Step 16, Loss: 1263.1090087890625 Step 17, Loss: 1143.2685546875 Step 18, Loss: 981.5584106445312 Step 19, Loss: 1296.2276611328125 Step 20, Loss: 1196.6866455078125 Step 21, Loss: 1264.32666015625 Step 22, Loss: 1215.798583984375 Step 23, Loss: 1191.38671875 Step 24, Loss: 1388.1199951171875 Step 25, Loss: 1487.18896484375 Step 26, Loss: 1353.6796875 Step 27, Loss: 1160.75732421875 Step 28, Loss: 1861.5467529296875 Step 29, Loss: 1208.218017578125 Step 30, Loss: 1362.111572265625 Step 31, Loss: 1071.198974609375 Step 32, Loss: 1252.173095703125 Step 33, Loss: 1355.49072265625 Step 34, Loss: 1508.05126953125 Step 35, Loss: 1267.271240234375 Step 36, Loss: 1185.5947265625 Step 37, Loss: 1482.4945068359375 Step 38, Loss: 1181.853759765625 Step 39, Loss: 1394.738525390625 Step 40, Loss: 1645.787109375 Step 41, Loss: 1171.0008544921875 Step 42, Loss: 1287.83935546875 Step 43, Loss: 993.2245483398438 Step 44, Loss: 1151.0255126953125 Step 45, Loss: 1318.048583984375 Step 46, Loss: 924.6395263671875 Step 47, Loss: 1511.181396484375 Step 48, Loss: 1202.577880859375 Step 49, Loss: 1508.792236328125 Step 50, Loss: 1505.3642578125 Step 51, Loss: 1560.315673828125 Step 52, Loss: 1317.7265625 Step 53, Loss: 1494.878662109375 Step 54, Loss: 1350.65966796875 Step 55, Loss: 1119.683837890625 Step 56, Loss: 1108.71435546875 Step 57, Loss: 1348.2799072265625 Step 58, Loss: 1156.2025146484375 Step 59, Loss: 1242.4732666015625 Step 60, Loss: 1325.41650390625 Step 61, Loss: 1322.845703125 Step 62, Loss: 1502.64111328125 Step 63, Loss: 1255.45654296875 Step 64, Loss: 1404.2796630859375 Step 65, Loss: 1595.4014892578125 Step 66, Loss: 1242.0611572265625 Step 67, Loss: 1469.566162109375 Step 68, Loss: 1601.4737548828125 Step 69, Loss: 1300.706787109375 Step 70, Loss: 1010.876220703125 Step 71, Loss: 1304.723876953125 Step 72, Loss: 1525.03662109375 Step 73, Loss: 973.072265625 Step 74, Loss: 1136.6474609375 Step 75, Loss: 1350.9725341796875 Step 76, Loss: 1347.3695068359375 Step 77, Loss: 1024.8515625 Step 78, Loss: 1117.653076171875 Step 79, Loss: 1313.831787109375 Step 80, Loss: 1625.9874267578125 Step 81, Loss: 1471.9322509765625 Step 82, Loss: 2217.120361328125 Step 83, Loss: 1050.25341796875 Step 84, Loss: 1215.480224609375 Step 85, Loss: 1341.806640625 Step 86, Loss: 1674.99072265625 Step 87, Loss: 1209.512451171875 Step 88, Loss: 1633.779296875 Step 89, Loss: 1832.9788818359375 Step 90, Loss: 1328.9354248046875 Step 91, Loss: 1182.6014404296875 Step 92, Loss: 1966.3428955078125 Step 93, Loss: 1258.536865234375 Step 94, Loss: 1890.312744140625 Step 95, Loss: 1252.579345703125 Step 96, Loss: 1126.8079833984375 Step 97, Loss: 1239.4892578125 Step 98, Loss: 1986.4447021484375 Step 99, Loss: 1606.381591796875 Step 100, Loss: 1837.0020751953125 Step 101, Loss: 1673.17578125 Step 102, Loss: 1055.4329833984375 Step 103, Loss: 1901.669189453125 Step 104, Loss: 1321.781494140625 Step 105, Loss: 1454.352294921875 Step 106, Loss: 1967.356201171875 Step 107, Loss: 1142.036376953125 Step 108, Loss: 1747.320556640625 Step 109, Loss: 1317.83447265625 Step 110, Loss: 1302.368408203125 Step 111, Loss: 2881.3740234375 Step 112, Loss: 1421.568603515625 Step 113, Loss: 1219.75 Step 114, Loss: 1560.2093505859375 Step 115, Loss: 1460.766845703125 Step 116, Loss: 1280.051513671875 Step 117, Loss: 1520.0809326171875 Step 118, Loss: 1670.777587890625 Step 119, Loss: 2165.53662109375 Step 120, Loss: 1461.3638916015625 Step 121, Loss: 1569.815185546875 Step 122, Loss: 1206.189697265625 Step 123, Loss: 1708.3341064453125 Step 124, Loss: 1029.6060791015625 Step 125, Loss: 2053.216796875 Step 126, Loss: 1506.0224609375 Step 127, Loss: 1525.813720703125 Step 128, Loss: 1918.3447265625 Step 129, Loss: 1060.704345703125 Step 130, Loss: 1065.752197265625 Step 131, Loss: 2078.143310546875 Step 132, Loss: 2007.9171142578125 Step 133, Loss: 1194.0960693359375 Step 134, Loss: 1399.7044677734375 Step 135, Loss: 1465.194580078125 Step 136, Loss: 1657.65478515625 Step 137, Loss: 1398.98779296875 Step 138, Loss: 1893.5042724609375 Step 139, Loss: 1590.546630859375 Step 140, Loss: 1150.5404052734375 Step 141, Loss: 1262.00341796875 Step 142, Loss: 1475.8515625 Step 143, Loss: 1419.84619140625 Step 144, Loss: 1982.8699951171875 Step 145, Loss: 1801.156494140625 Step 146, Loss: 1469.816650390625 Step 147, Loss: 1709.6822509765625 Step 148, Loss: 2077.054443359375 Step 149, Loss: 1249.806396484375 Step 150, Loss: 1342.635009765625 Step 151, Loss: 1223.06884765625 Step 152, Loss: 1386.104248046875 Step 153, Loss: 1990.470458984375 Step 154, Loss: 1389.048095703125 Step 155, Loss: 1889.6376953125 Step 156, Loss: 2174.4853515625 Step 157, Loss: 1561.5816650390625 Step 158, Loss: 1745.8248291015625 Step 159, Loss: 1040.3082275390625 Step 160, Loss: 1628.5341796875 Step 161, Loss: 1864.4423828125 Step 162, Loss: 1203.421142578125 Step 163, Loss: 1344.75390625 Step 164, Loss: 1595.09033203125 Step 165, Loss: 1238.3028564453125 Step 166, Loss: 1102.130859375 Step 167, Loss: 1766.8060302734375 Step 168, Loss: 1282.4505615234375 Step 169, Loss: 1372.6455078125 Step 170, Loss: 1287.114501953125 Step 171, Loss: 2015.166015625 Step 172, Loss: 1478.1759033203125 Step 173, Loss: 1425.34765625 Step 174, Loss: 1295.652587890625 Step 175, Loss: 1922.87060546875 Step 176, Loss: 1617.073974609375 Step 177, Loss: 1533.184326171875 Step 178, Loss: 1334.359619140625 Step 179, Loss: 1599.361328125 Step 180, Loss: 1004.377197265625 Step 181, Loss: 1357.1324462890625 Step 182, Loss: 1589.1905517578125 Step 183, Loss: 1594.9830322265625 Step 184, Loss: 1796.361328125 Step 185, Loss: 1063.1932373046875 Step 186, Loss: 1612.2103271484375 Step 187, Loss: 1179.7874755859375 Step 188, Loss: 1238.40380859375 Step 189, Loss: 1625.638916015625 Step 190, Loss: 1425.85595703125 Step 191, Loss: 1521.1871337890625 Step 192, Loss: 1244.6064453125 Step 193, Loss: 1631.3555908203125 Step 194, Loss: 1567.26611328125 Step 195, Loss: 1530.6263427734375 Step 196, Loss: 1284.45361328125 Step 197, Loss: 1493.089599609375 Step 198, Loss: 1071.772216796875 Step 199, Loss: 1706.98193359375 Step 200, Loss: 1632.34326171875 Step 201, Loss: 2103.3896484375 Step 202, Loss: 954.9825439453125 Step 203, Loss: 1323.804443359375 Step 204, Loss: 1528.763916015625 Step 205, Loss: 2031.8448486328125 Step 206, Loss: 1118.1773681640625 Step 207, Loss: 1317.7890625 Step 208, Loss: 1503.223388671875 Step 209, Loss: 1808.2283935546875 Step 210, Loss: 1703.98974609375 Step 211, Loss: 1049.0858154296875 Step 212, Loss: 1275.987060546875 Step 213, Loss: 1196.9306640625 Step 214, Loss: 1411.6932373046875 Step 215, Loss: 1246.766357421875 Step 216, Loss: 987.7606811523438 Step 217, Loss: 1669.424560546875 Step 218, Loss: 1411.1510009765625 Step 219, Loss: 1460.197998046875 Step 220, Loss: 1075.899169921875 Step 221, Loss: 1218.0361328125 Step 222, Loss: 1453.509033203125 Step 223, Loss: 1131.749267578125 Step 224, Loss: 1093.9444580078125 Step 225, Loss: 1427.131591796875 Step 226, Loss: 1485.5362548828125 Step 227, Loss: 1257.9765625 Step 228, Loss: 831.0469970703125 Step 229, Loss: 1500.1734619140625 Step 230, Loss: 883.95947265625 Step 231, Loss: 1360.126708984375 Step 232, Loss: 1227.45166015625 Step 233, Loss: 1387.870849609375 Step 234, Loss: 2018.3006591796875 Step 235, Loss: 1122.8701171875 Step 236, Loss: 1076.186279296875 Step 237, Loss: 1436.813720703125 Step 238, Loss: 1551.4228515625 Step 239, Loss: 1010.64208984375 Step 240, Loss: 1146.5865478515625 Step 241, Loss: 1784.942138671875 Step 242, Loss: 1605.85986328125 Step 243, Loss: 1865.7120361328125 Step 244, Loss: 1131.7679443359375 Step 245, Loss: 1586.39501953125 Step 246, Loss: 1053.21240234375 Step 247, Loss: 1830.83203125 Step 248, Loss: 1143.00146484375 Step 249, Loss: 1098.12890625 Step 250, Loss: 1623.798583984375 Step 251, Loss: 931.1444702148438 Step 252, Loss: 1191.5740966796875 Step 253, Loss: 1232.3299560546875 Step 254, Loss: 1454.0263671875 Step 255, Loss: 1281.9136962890625 Step 256, Loss: 1520.6201171875 Step 257, Loss: 1286.822998046875 Step 258, Loss: 1148.400146484375 Step 259, Loss: 1985.4564208984375 Step 260, Loss: 1180.51220703125 Step 261, Loss: 1593.327880859375 Step 262, Loss: 1154.2291259765625 Step 263, Loss: 1704.37353515625 Step 264, Loss: 1198.2415771484375 Step 265, Loss: 979.8606567382812 Step 266, Loss: 966.625732421875 Step 267, Loss: 1343.4063720703125 Step 268, Loss: 1128.93701171875 Step 269, Loss: 1148.60986328125 Step 270, Loss: 1303.980224609375 Step 271, Loss: 1707.3887939453125 Step 272, Loss: 1212.033203125 Step 273, Loss: 1485.076416015625 Step 274, Loss: 802.6442260742188 Step 275, Loss: 1658.125 Step 276, Loss: 1155.947265625 Step 277, Loss: 1341.27978515625 Step 278, Loss: 1107.769775390625 Step 279, Loss: 1477.561767578125 Step 280, Loss: 1592.630126953125 Step 281, Loss: 1608.114501953125 Step 282, Loss: 1213.718017578125 Step 283, Loss: 1298.125 Step 284, Loss: 1297.184326171875 Step 285, Loss: 1524.143798828125 Step 286, Loss: 2174.5380859375 Step 287, Loss: 1597.50244140625 Step 288, Loss: 959.2469482421875 Step 289, Loss: 1169.5859375 Step 290, Loss: 1442.810302734375 Step 291, Loss: 1034.099853515625 Step 292, Loss: 998.4970703125 Step 293, Loss: 1627.68798828125 Step 294, Loss: 1459.016845703125 Step 295, Loss: 1093.124267578125 Step 296, Loss: 1057.6783447265625 Step 297, Loss: 1432.466796875 Step 298, Loss: 1088.4395751953125 Step 299, Loss: 1693.2427978515625 Step 300, Loss: 1282.32421875 Step 301, Loss: 920.5098876953125 Step 302, Loss: 1827.253662109375 Step 303, Loss: 1689.7030029296875 Step 304, Loss: 1525.353515625 Step 305, Loss: 1677.1201171875 Step 306, Loss: 1784.0885009765625 Step 307, Loss: 1434.113037109375 Step 308, Loss: 1304.2841796875 Step 309, Loss: 1061.751953125 Step 310, Loss: 1253.636474609375 Step 311, Loss: 1144.22802734375 Step 312, Loss: 1889.0247802734375 Step 313, Loss: 1265.8465576171875 Step 314, Loss: 1193.4488525390625 Step 315, Loss: 1813.67578125 Step 316, Loss: 1387.3717041015625 Step 317, Loss: 1652.5 Step 318, Loss: 2130.800537109375 Step 319, Loss: 1272.6710205078125 Step 320, Loss: 1232.8660888671875 Step 321, Loss: 1048.517822265625 Step 322, Loss: 1738.4168701171875 Step 323, Loss: 994.63671875 Step 324, Loss: 1520.3111572265625 Step 325, Loss: 1712.8887939453125 Step 326, Loss: 1547.9693603515625 Step 327, Loss: 1156.1005859375 Step 328, Loss: 1533.877197265625 Step 329, Loss: 1518.6517333984375 Step 330, Loss: 1046.860595703125 Step 331, Loss: 1764.745361328125 Step 332, Loss: 1499.9609375 Step 333, Loss: 1422.4716796875 Step 334, Loss: 1091.22998046875 Step 335, Loss: 873.0045166015625 Step 336, Loss: 1054.060302734375 Step 337, Loss: 1297.4365234375 Step 338, Loss: 1910.373291015625 Step 339, Loss: 1235.56494140625 Step 340, Loss: 1540.74609375 Step 341, Loss: 1384.647705078125 Step 342, Loss: 1600.019775390625 Step 343, Loss: 1365.9439697265625 Step 344, Loss: 1584.5767822265625 Step 345, Loss: 1595.54736328125 Step 346, Loss: 1569.956298828125 Step 347, Loss: 1193.80712890625 Step 348, Loss: 953.3492431640625 Step 349, Loss: 1395.6644287109375 Step 350, Loss: 1400.4874267578125 Step 351, Loss: 1451.081298828125 Step 352, Loss: 1007.0923461914062 Step 353, Loss: 1332.080322265625 Step 354, Loss: 1357.7301025390625 Step 355, Loss: 953.5802001953125 Step 356, Loss: 1283.988525390625 Step 357, Loss: 1157.7041015625 Step 358, Loss: 1377.642822265625 Step 359, Loss: 1074.5054931640625 Step 360, Loss: 1783.91455078125 Step 361, Loss: 1339.19677734375 Step 362, Loss: 1765.009765625 Step 363, Loss: 1154.8880615234375 Step 364, Loss: 1303.498291015625 Step 365, Loss: 1372.677734375 Step 366, Loss: 1622.110107421875 Step 367, Loss: 1271.9083251953125 Step 368, Loss: 1147.957763671875 Step 369, Loss: 1974.82080078125 Step 370, Loss: 1554.877197265625 Step 371, Loss: 1458.959228515625 Step 372, Loss: 1624.116455078125 Step 373, Loss: 1260.9439697265625 Step 374, Loss: 1553.80419921875 Step 375, Loss: 1050.49462890625 Step 376, Loss: 1284.7432861328125 Step 377, Loss: 1408.9525146484375 Step 378, Loss: 1273.099609375 Step 379, Loss: 1459.145751953125 Step 380, Loss: 1198.253173828125 Step 381, Loss: 1334.21826171875 Step 382, Loss: 1057.520263671875 Step 383, Loss: 1203.5245361328125 Step 384, Loss: 1140.5009765625 Step 385, Loss: 1093.99462890625 Step 386, Loss: 1212.078857421875 Step 387, Loss: 1388.4498291015625 Step 388, Loss: 1805.5279541015625 Step 389, Loss: 1513.4420166015625 Step 390, Loss: 1375.2115478515625 Step 391, Loss: 1075.574951171875 Step 392, Loss: 1794.258544921875 Step 393, Loss: 1790.400146484375 Step 394, Loss: 1223.4815673828125 Step 395, Loss: 1891.246826171875 Step 396, Loss: 1332.2364501953125 Step 397, Loss: 1290.03759765625 Step 398, Loss: 1758.0888671875 Step 399, Loss: 1058.5726318359375 Step 400, Loss: 912.8321533203125 Step 401, Loss: 1143.2447509765625 Step 402, Loss: 1383.248291015625 Step 403, Loss: 1202.1904296875 Step 404, Loss: 1598.3101806640625 Step 405, Loss: 1495.3203125 Step 406, Loss: 1449.4405517578125 Step 407, Loss: 1325.498779296875 Step 408, Loss: 1168.171875 Step 409, Loss: 830.4091796875 Step 410, Loss: 1674.7080078125 Step 411, Loss: 1399.978759765625 Step 412, Loss: 1705.4149169921875 Step 413, Loss: 2014.07421875 Step 414, Loss: 1577.784912109375 Step 415, Loss: 1314.164306640625 Step 416, Loss: 1251.5391845703125 Step 417, Loss: 1779.771728515625 Step 418, Loss: 1411.2890625 Step 419, Loss: 898.7115478515625 Step 420, Loss: 1513.54833984375 Step 421, Loss: 1109.6455078125 Step 422, Loss: 1550.4925537109375 Step 423, Loss: 1210.365966796875 Step 424, Loss: 1833.0389404296875 Step 425, Loss: 1037.49365234375 Step 426, Loss: 1230.4632568359375 Step 427, Loss: 1338.976318359375 Step 428, Loss: 1424.70947265625 Step 429, Loss: 1671.7451171875 Step 430, Loss: 1829.2230224609375 Step 431, Loss: 1668.832763671875 Step 432, Loss: 1546.8280029296875 Step 433, Loss: 1254.328369140625 Step 434, Loss: 1617.667724609375 Step 435, Loss: 1311.4676513671875 Step 436, Loss: 1078.8048095703125 Step 437, Loss: 1459.23828125 Step 438, Loss: 1391.421630859375 Step 439, Loss: 1024.3675537109375 Step 440, Loss: 1114.6915283203125 Step 441, Loss: 1341.411865234375 Step 442, Loss: 990.5977783203125 Step 443, Loss: 1176.3507080078125 Step 444, Loss: 1338.45361328125 Step 445, Loss: 1037.4556884765625 Step 446, Loss: 1386.7548828125 Step 447, Loss: 1723.485595703125 Step 448, Loss: 2248.262451171875 Step 449, Loss: 1507.2198486328125 Step 450, Loss: 1902.50732421875 Step 451, Loss: 1609.2783203125 Step 452, Loss: 883.104248046875 Step 453, Loss: 933.4012451171875 Step 454, Loss: 1770.635009765625 Step 455, Loss: 1091.0963134765625 Step 456, Loss: 1338.765625 Step 457, Loss: 1500.919677734375 Step 458, Loss: 1115.61279296875 Step 459, Loss: 1865.16357421875 Step 460, Loss: 1275.708740234375 Step 461, Loss: 1856.6123046875 Step 462, Loss: 1179.163330078125 Step 463, Loss: 785.0582275390625 Step 464, Loss: 1065.76708984375 Step 465, Loss: 1140.35302734375 Step 466, Loss: 1262.082275390625 Step 467, Loss: 1484.341064453125 Step 468, Loss: 1126.927001953125 Step 469, Loss: 1812.801513671875 Step 470, Loss: 1583.437255859375 Step 471, Loss: 1389.092529296875 Step 472, Loss: 973.5128173828125 Step 473, Loss: 2022.6402587890625 Step 474, Loss: 1326.77734375 Step 475, Loss: 1460.857421875 Step 476, Loss: 1146.515869140625 Step 477, Loss: 975.498291015625 Step 478, Loss: 900.0221557617188 Step 479, Loss: 1166.374267578125 Step 480, Loss: 1737.30126953125 Step 481, Loss: 1062.071044921875 Step 482, Loss: 1642.854248046875 Step 483, Loss: 1408.7431640625 Step 484, Loss: 1588.709716796875 Step 485, Loss: 895.8580322265625 Step 486, Loss: 1766.7691650390625 Step 487, Loss: 863.95654296875 Step 488, Loss: 1673.743408203125 Step 489, Loss: 1306.481201171875 Step 490, Loss: 1376.691650390625 Step 491, Loss: 1305.647705078125 Step 492, Loss: 1177.880859375 Step 493, Loss: 1562.1954345703125 Step 494, Loss: 2219.90478515625 Step 495, Loss: 911.30419921875 Step 496, Loss: 1514.813232421875 Step 497, Loss: 1346.8331298828125 Step 498, Loss: 991.4085693359375 Step 499, Loss: 977.5787353515625 Step 500, Loss: 1657.347412109375 Step 501, Loss: 1287.251953125 Step 502, Loss: 743.5289306640625 Step 503, Loss: 1239.245361328125 Step 504, Loss: 1049.7587890625 Step 505, Loss: 1977.7630615234375 Step 506, Loss: 1106.8863525390625 Step 507, Loss: 1745.14404296875 Step 508, Loss: 1097.395751953125 Step 509, Loss: 1811.0892333984375 Step 510, Loss: 1231.067626953125 Step 511, Loss: 1381.961669921875 Step 512, Loss: 1410.78662109375 Step 513, Loss: 1118.81787109375 Step 514, Loss: 1412.028076171875 Step 515, Loss: 1493.8941650390625 Step 516, Loss: 1394.9820556640625 Step 517, Loss: 1060.9759521484375 Step 518, Loss: 960.06982421875 Step 519, Loss: 1316.806884765625 Step 520, Loss: 1471.657958984375 Step 521, Loss: 1293.024658203125 Step 522, Loss: 1042.2998046875 Step 523, Loss: 1133.102294921875 Step 524, Loss: 1363.699951171875 Step 525, Loss: 1608.5966796875 Step 526, Loss: 1173.14794921875 Step 527, Loss: 1066.2161865234375 Step 528, Loss: 1780.4852294921875 Step 529, Loss: 1444.552978515625 Step 530, Loss: 942.1851806640625 Step 531, Loss: 1389.2520751953125 Step 532, Loss: 1434.215087890625 Step 533, Loss: 1865.0098876953125 Step 534, Loss: 1295.721435546875 Step 535, Loss: 1056.942138671875 Step 536, Loss: 1707.695556640625 Step 537, Loss: 1559.7491455078125 Step 538, Loss: 1124.744873046875 Step 539, Loss: 1057.86767578125 Step 540, Loss: 1226.157470703125 Step 541, Loss: 1371.0052490234375 Step 542, Loss: 1402.97998046875 Step 543, Loss: 1217.48876953125 Step 544, Loss: 1207.627685546875 Step 545, Loss: 1065.9107666015625 Step 546, Loss: 1155.632080078125 Step 547, Loss: 1536.2127685546875 Step 548, Loss: 1335.600341796875 Step 549, Loss: 1116.26904296875 Step 550, Loss: 1721.7852783203125 Step 551, Loss: 1162.0531005859375 Step 552, Loss: 1553.541015625 Step 553, Loss: 1483.7789306640625 Step 554, Loss: 1179.0186767578125 Step 555, Loss: 1394.5513916015625 Step 556, Loss: 1367.4727783203125 Step 557, Loss: 1133.940673828125 Step 558, Loss: 945.0025024414062 Step 559, Loss: 1859.411865234375 Step 560, Loss: 1189.91650390625 Step 561, Loss: 1205.9249267578125 Step 562, Loss: 1007.8544311523438 Step 563, Loss: 1241.7750244140625 Step 564, Loss: 1296.9080810546875 Step 565, Loss: 1383.333740234375 Step 566, Loss: 2233.53515625 Step 567, Loss: 1587.458251953125 Step 568, Loss: 1825.3814697265625 Step 569, Loss: 1371.637451171875 Step 570, Loss: 1573.197265625 Step 571, Loss: 1047.68017578125 Step 572, Loss: 907.6341552734375 Step 573, Loss: 1212.4932861328125 Step 574, Loss: 2112.6806640625 Step 575, Loss: 950.9673461914062 Step 576, Loss: 1824.330810546875 Step 577, Loss: 1359.4169921875 Step 578, Loss: 1560.9697265625 Step 579, Loss: 1049.075439453125 Step 580, Loss: 1353.3956298828125 Step 581, Loss: 1516.19921875 Step 582, Loss: 1597.74951171875 Step 583, Loss: 1127.2060546875 Step 584, Loss: 1558.127685546875 Step 585, Loss: 1831.9423828125 Step 586, Loss: 1555.7625732421875 Step 587, Loss: 1547.912109375 Step 588, Loss: 1655.69140625 Step 589, Loss: 1040.394287109375 Step 590, Loss: 1128.268310546875 Step 591, Loss: 1267.056884765625 Step 592, Loss: 1277.0677490234375 Step 593, Loss: 1127.994384765625 Step 594, Loss: 1245.7625732421875 Step 595, Loss: 1278.789306640625 Step 596, Loss: 2345.984619140625 Step 597, Loss: 2026.3302001953125 Step 598, Loss: 1655.1656494140625 Step 599, Loss: 1053.072265625 Step 600, Loss: 1770.6749267578125 Step 601, Loss: 992.5887451171875 Step 602, Loss: 1830.2685546875 Step 603, Loss: 1374.114990234375 Step 604, Loss: 1073.7449951171875 Step 605, Loss: 1183.1143798828125 Step 606, Loss: 1291.30322265625 Step 607, Loss: 1838.1009521484375 Step 608, Loss: 1740.20556640625 Step 609, Loss: 937.88232421875 Step 610, Loss: 1357.314697265625 Step 611, Loss: 1289.94921875 Step 612, Loss: 1513.4420166015625 Step 613, Loss: 1593.559814453125 Step 614, Loss: 1200.7783203125 Step 615, Loss: 1296.7430419921875 Step 616, Loss: 1384.037109375 Step 617, Loss: 995.932861328125 Step 618, Loss: 1321.1334228515625 Step 619, Loss: 1980.3076171875 Step 620, Loss: 1491.490966796875 Step 621, Loss: 1263.7740478515625 Step 622, Loss: 1492.602783203125 Step 623, Loss: 1138.0784912109375 Step 624, Loss: 1257.24072265625 Step 625, Loss: 1708.15283203125
# torch.save(cm_unet.state_dict(), 'mbcd_model.pth')
И в последний раз сэмплируем¶
Важно: теперь у нас появляется возможно сэмплировать детерминистично с помощью оригинального солверва DDIM за 4 шага. Так что возвращаем сэмплирование исходным pipe-ом.
Ниже прикрепляем референс и напомним, что у вас картинки могут отличаться и быть чуть хуже/лучше.

pipe.unet = cm_unet.eval().to(torch.float16)
assert cm_unet.active_adapter == 'multi-cd'
guidance_scale = 1
for prompt in validation_prompts:
generator = torch.Generator(device="cuda").manual_seed(1)
images = pipe(
prompt=prompt,
num_inference_steps=4,
num_images_per_prompt=4,
generator=generator,
guidance_scale=guidance_scale,
).images
visualize_images(images)
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
Возвращаемся в колаб к остальным адаптерам¶
# Загрузка предыдущих адаптеров
load = True
if load:
cm_unet = get_peft_model(unet, lora_config, adapter_name="ct")
cm_unet.enable_gradient_checkpointing()
cm_unet.load_state_dict(torch.load('/content/drive/MyDrive/cv_model/cm_model.pth'))
# Добавляем адаптер "cd" и загружаем его параметры
cm_unet.add_adapter("cd", lora_config)
cm_unet.set_adapter("cd")
cm_unet.load_state_dict(torch.load('/content/drive/MyDrive/cv_model/cd_model.pth'))
# Добавляем адаптер "multi-cd" и загружаем его параметры
cm_unet.add_adapter("multi-cd", lora_config)
cm_unet.set_adapter("multi-cd")
cm_unet.load_state_dict(torch.load('/content/drive/MyDrive/cv_model/mbcd_model.pth'))
Задание №8¶
Все, что осталось сделать - это загрузить ваши обученные модельки на huggingface_hub. Это очень популярный и удобный способ для хранения моделей, которые легко можно загружать и подставлять в модель. Другими словами GitHub для моделей и датасетов.
Создайте аккаунт на huggingface.co
Получите свой HF токен, который можно получить здесь: https://huggingface.co/settings/tokens
Создайте репозиторий для ваших моделями https://huggingface.co/new
Важно: перед отправкой нотбука на проверку, не забудьте удалить свой hf токен!
cm_unet.push_to_hub(
"eteron/cv_week_final_model", # "<username>/<repo-name>"
token=my_token
)
README.md: 0%| | 0.00/31.0 [00:00<?, ?B/s]
adapter_model.safetensors: 0%| | 0.00/538M [00:00<?, ?B/s]
adapter_model.safetensors: 0%| | 0.00/269M [00:00<?, ?B/s]
Upload 3 LFS files: 0%| | 0/3 [00:00<?, ?it/s]
adapter_model.safetensors: 0%| | 0.00/269M [00:00<?, ?B/s]
CommitInfo(commit_url='https://huggingface.co/eteron/cv_week_final_model/commit/7415ad7aa7101a82e3c0ab56c3bdc325330e8c80', commit_message='Upload model', commit_description='', oid='7415ad7aa7101a82e3c0ab56c3bdc325330e8c80', pr_url=None, repo_url=RepoUrl('https://huggingface.co/eteron/cv_week_final_model', endpoint='https://huggingface.co', repo_type='model', repo_id='eteron/cv_week_final_model'), pr_revision=None, pr_num=None)
Пример, как должен выглядеть результат выполнения команды: https://huggingface.co/dbaranchuk/cv-week-final-task-example
Давайте проверим, что загрузка модели корректно работает.
from peft import PeftModel
loaded_cm_unet = PeftModel.from_pretrained(
unet,
"eteron/cv_week_final_model",
token=my_token,
subfolder='multi-cd',
adapter_name="multi-cd",
)
multi-cd/adapter_config.json: 0%| | 0.00/945 [00:00<?, ?B/s]
adapter_model.safetensors: 0%| | 0.00/269M [00:00<?, ?B/s]
pipe.unet = loaded_cm_unet.eval().to(torch.float16)
assert loaded_cm_unet.active_adapter == 'multi-cd'
guidance_scale = 1
for prompt in validation_prompts:
generator = torch.Generator(device="cuda").manual_seed(1)
images = pipe(
prompt=prompt,
num_inference_steps=4,
num_images_per_prompt=4,
generator=generator,
guidance_scale=guidance_scale,
).images
visualize_images(images)
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
На этом все! Ура!